diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 138e5030..be0c9a66 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -17,7 +17,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_notrace_module
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
 from .registry import register_model
@@ -124,7 +123,6 @@ class ConvNeXtBlock(nn.Module):
         norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
-        self.shortcut_after_dw = stride > 1
 
         self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
@@ -135,9 +133,6 @@ class ConvNeXtBlock(nn.Module):
     def forward(self, x):
         shortcut = x
         x = self.conv_dw(x)
-        if self.shortcut_after_dw:
-            shortcut = x
-
         if self.use_conv_mlp:
             x = self.norm(x)
             x = self.mlp(x)
@@ -150,7 +145,6 @@ class ConvNeXtBlock(nn.Module):
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
 
         x = self.drop_path(x) + shortcut
-        #print('b', x.shape)
         return x
 
 
@@ -164,7 +158,6 @@ class ConvNeXtStage(nn.Module):
             depth=2,
             drop_path_rates=None,
             ls_init_value=1.0,
-            downsample_block=False,
             conv_mlp=False,
             conv_bias=True,
             norm_layer=None,
@@ -173,14 +166,14 @@ class ConvNeXtStage(nn.Module):
         super().__init__()
         self.grad_checkpointing = False
 
-        if downsample_block or (in_chs == out_chs and stride == 1):
-            self.downsample = nn.Identity()
-        else:
+        if in_chs != out_chs or stride > 1:
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
             )
             in_chs = out_chs
+        else:
+            self.downsample = nn.Identity()
 
         drop_path_rates = drop_path_rates or [0.] * depth
         stage_blocks = []
@@ -188,7 +181,6 @@ class ConvNeXtStage(nn.Module):
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
-                stride=stride if downsample_block and i == 0 else 1,
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -236,7 +228,6 @@ class ConvNeXt(nn.Module):
             stem_stride=4,
             head_init_scale=1.,
             head_norm_first=False,
-            downsample_block=False,
             conv_mlp=False,
             conv_bias=True,
             norm_layer=None,
@@ -291,7 +282,6 @@ class ConvNeXt(nn.Module):
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
-                downsample_block=downsample_block,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
                 norm_layer=norm_layer,
@@ -418,7 +408,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 @register_model
 def convnext_nano_ols(pretrained=False, **kwargs):
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True,
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640),
         head_norm_first=True, conv_mlp=True, conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
@@ -426,7 +416,8 @@ def convnext_nano_ols(pretrained=False, **kwargs):
 
 @register_model
 def convnext_tiny_hnf(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
+    model_args = dict(
+        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
     model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
     return model
 
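
Note on the ConvNeXtStage change above: with the experimental downsample_block path removed, downsampling always lives in the stage's downsample module (norm + strided conv) rather than optionally inside the first block via its depthwise conv stride. The following is a minimal, self-contained sketch of the new selection logic, not the module itself: make_downsample is a hypothetical helper, and nn.GroupNorm(1, ...) stands in for timm's LayerNorm2d so the snippet runs without timm installed.

    import torch
    import torch.nn as nn

    def make_downsample(in_chs, out_chs, stride, conv_bias=True):
        # Mirrors the refactored ConvNeXtStage.__init__ logic: any channel or
        # spatial change gets norm + strided conv, otherwise identity.
        if in_chs != out_chs or stride > 1:
            return nn.Sequential(
                nn.GroupNorm(1, in_chs),  # stand-in for timm's LayerNorm2d
                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
            )
        return nn.Identity()

    x = torch.randn(1, 96, 56, 56)
    print(make_downsample(96, 192, stride=2)(x).shape)  # torch.Size([1, 192, 28, 28])
    print(make_downsample(96, 96, stride=1)(x).shape)   # torch.Size([1, 96, 56, 56]), identity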