diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index be0c9a66..379937e0 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -109,6 +109,7 @@ class ConvNeXtBlock(nn.Module):
             dim,
             dim_out=None,
             stride=1,
+            dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
@@ -124,7 +125,8 @@ class ConvNeXtBlock(nn.Module):
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
-        self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
+        self.conv_dw = create_conv2d(
+            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
         self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
@@ -156,6 +158,7 @@ class ConvNeXtStage(nn.Module):
             out_chs,
             stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
             conv_mlp=False,
@@ -166,10 +169,14 @@ class ConvNeXtStage(nn.Module):
         super().__init__()
         self.grad_checkpointing = False
 
-        if in_chs != out_chs or stride > 1:
+        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
+                create_conv2d(
+                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
+                    dilation=dilation[0], padding=pad, bias=conv_bias),
             )
             in_chs = out_chs
         else:
@@ -181,6 +188,7 @@ class ConvNeXtStage(nn.Module):
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -235,7 +243,7 @@ class ConvNeXt(nn.Module):
             drop_path_rate=0.,
     ):
         super().__init__()
-        assert output_stride == 32
+        assert output_stride in (8, 16, 32)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -263,22 +271,27 @@ class ConvNeXt(nn.Module):
                 padding=stem_kernel_size // 2, bias=conv_bias),
             norm_layer(dims[0]),
         )
-        prev_chs = dims[0]
-        curr_stride = stem_stride
 
         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
+        prev_chs = dims[0]
+        curr_stride = stem_stride
+        dilation = 1
         # 4 feature resolution stages, each consisting of multiple residual blocks
         for i in range(4):
             stride = 2 if curr_stride == 2 or i > 0 else 1
-            # FIXME support dilation / output_stride
+            if curr_stride >= output_stride and stride > 1:
+                dilation *= stride
+                stride = 1
             curr_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
             out_chs = dims[i]
             stages.append(ConvNeXtStage(
                 prev_chs,
                 out_chs,
                 stride=stride,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
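
A quick sanity-check sketch (not part of the patch, and assuming this patched timm checkout is importable): with output_stride=16 the final stage should trade its stride-2 downsample for dilation, so the deepest feature map of a 224x224 input comes out at 14x14 instead of the default 7x7.

import torch
import timm

# Build a ConvNeXt feature extractor with the new output_stride argument.
# convnext_tiny has dims (96, 192, 384, 768), so the last feature map has 768 channels.
model = timm.create_model('convnext_tiny', features_only=True, output_stride=16)

x = torch.randn(1, 3, 224, 224)
for i, feat in enumerate(model(x)):
    print(i, tuple(feat.shape))
# Expected last line: (1, 768, 14, 14), i.e. 224 / 16 = 14 per side,
# versus (1, 768, 7, 7) with the default output_stride=32.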