Remove experimental downsample-in-block support in ConvNeXt. Needs further experimentation before keeping it in.

pull/1327/head
Ross Wightman 2 years ago
parent bfc0dccb0e
commit 06307b8b41

@@ -17,7 +17,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_notrace_module
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
 from .registry import register_model
@@ -124,7 +123,6 @@ class ConvNeXtBlock(nn.Module):
         norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
-        self.shortcut_after_dw = stride > 1
         self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
@@ -135,9 +133,6 @@ class ConvNeXtBlock(nn.Module):
     def forward(self, x):
         shortcut = x
         x = self.conv_dw(x)
-        if self.shortcut_after_dw:
-            shortcut = x
         if self.use_conv_mlp:
             x = self.norm(x)
             x = self.mlp(x)
@@ -150,7 +145,6 @@ class ConvNeXtBlock(nn.Module):
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
         x = self.drop_path(x) + shortcut
-        #print('b', x.shape)
         return x
@@ -164,7 +158,6 @@ class ConvNeXtStage(nn.Module):
             depth=2,
             drop_path_rates=None,
             ls_init_value=1.0,
-            downsample_block=False,
             conv_mlp=False,
             conv_bias=True,
             norm_layer=None,
@@ -173,14 +166,14 @@ class ConvNeXtStage(nn.Module):
         super().__init__()
         self.grad_checkpointing = False
 
-        if downsample_block or (in_chs == out_chs and stride == 1):
-            self.downsample = nn.Identity()
-        else:
+        if in_chs != out_chs or stride > 1:
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
             )
             in_chs = out_chs
+        else:
+            self.downsample = nn.Identity()
 
         drop_path_rates = drop_path_rates or [0.] * depth
         stage_blocks = []
@@ -188,7 +181,6 @@ class ConvNeXtStage(nn.Module):
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
-                stride=stride if downsample_block and i == 0 else 1,
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -236,7 +228,6 @@ class ConvNeXt(nn.Module):
             stem_stride=4,
             head_init_scale=1.,
             head_norm_first=False,
-            downsample_block=False,
             conv_mlp=False,
             conv_bias=True,
             norm_layer=None,
@@ -291,7 +282,6 @@ class ConvNeXt(nn.Module):
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
-                downsample_block=downsample_block,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
                 norm_layer=norm_layer,
@@ -418,7 +408,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 @register_model
 def convnext_nano_ols(pretrained=False, **kwargs):
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True,
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True,
         conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
@@ -426,7 +416,8 @@ def convnext_nano_ols(pretrained=False, **kwargs):
 @register_model
 def convnext_tiny_hnf(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
+    model_args = dict(
+        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
 model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
 return model
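
For context on what is being removed: the experimental path let a stage's first block perform the downsampling itself (strided depthwise conv, with the residual shortcut taken after that conv), instead of the default separate norm + strided-conv downsample module ahead of the blocks. Below is a minimal sketch contrasting the two approaches, assuming simplified, hypothetical block and stage helpers (SimpleBlock, stage_separate_downsample, stage_downsample_in_block); it is not timm's actual implementation.

# Minimal sketch (simplified, not timm's code) of the two downsampling strategies.
import torch
import torch.nn as nn


class SimpleBlock(nn.Module):
    """Simplified ConvNeXt-style block: depthwise conv -> norm -> MLP -> residual add."""

    def __init__(self, dim, dim_out, stride=1):
        super().__init__()
        # Removed experimental behaviour: when the block itself downsamples
        # (stride > 1), take the shortcut *after* the depthwise conv so the
        # residual shapes still match.
        self.shortcut_after_dw = stride > 1
        self.conv_dw = nn.Conv2d(dim, dim_out, 7, stride=stride, padding=3, groups=dim)
        self.norm = nn.GroupNorm(1, dim_out)  # stand-in for LayerNorm2d
        self.mlp = nn.Sequential(
            nn.Conv2d(dim_out, 4 * dim_out, 1), nn.GELU(), nn.Conv2d(4 * dim_out, dim_out, 1))

    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        if self.shortcut_after_dw:
            shortcut = x
        x = self.mlp(self.norm(x))
        return x + shortcut


def stage_separate_downsample(in_chs, out_chs, depth):
    # Path kept by this commit: a norm + strided conv changes channels/resolution,
    # then every block runs at stride 1.
    downsample = nn.Sequential(
        nn.GroupNorm(1, in_chs),
        nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2))
    return nn.Sequential(downsample, *[SimpleBlock(out_chs, out_chs) for _ in range(depth)])


def stage_downsample_in_block(in_chs, out_chs, depth):
    # Removed experimental path: no downsample module; the first block carries stride 2.
    blocks = [SimpleBlock(in_chs, out_chs, stride=2)]
    blocks += [SimpleBlock(out_chs, out_chs) for _ in range(depth - 1)]
    return nn.Sequential(*blocks)


if __name__ == '__main__':
    x = torch.randn(1, 80, 56, 56)
    print(stage_separate_downsample(80, 160, 2)(x).shape)  # torch.Size([1, 160, 28, 28])
    print(stage_downsample_in_block(80, 160, 2)(x).shape)  # torch.Size([1, 160, 28, 28])

Both stages map a 56x56, 80-channel input to a 28x28, 160-channel output; the difference is only where the stride lives and where the shortcut is taken.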
