diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index e0189720..8098958c 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -521,14 +521,11 @@ class MetaFormerBlock(nn.Module):
                  norm_layer=nn.LayerNorm,
                  drop=0., drop_path=0.,
                  layer_scale_init_value=None,
-                 res_scale_init_value=None,
-                 downsample = nn.Identity()
+                 res_scale_init_value=None
                  ):
         super().__init__()
-
-        self.downsample = downsample
-
+
         self.norm1 = norm_layer(dim)
         self.token_mixer = token_mixer(dim=dim, drop=drop)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -546,7 +543,6 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
-        x = self.downsample(x)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -653,18 +649,19 @@ class MetaFormer(nn.Module):
         stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
         cur = 0
         for i in range(num_stage):
-            stage = nn.Sequential(*[MetaFormerBlock(
-                dim=dims[i],
-                token_mixer=token_mixers[i],
-                mlp=mlps[i],
-                norm_layer=norm_layers[i],
-                drop_path=dp_rates[cur + j],
-                layer_scale_init_value=layer_scale_init_values[i],
-                res_scale_init_value=res_scale_init_values[i],
-                downsample = downsample_layers[i]
-            ) for j in range(depths[i])]
+            stage = nn.Sequential(OrderedDict([
+                ('downsample', downsample_layers[i]),
+                ('blocks', nn.Sequential(*[MetaFormerBlock(
+                    dim=dims[i],
+                    token_mixer=token_mixers[i],
+                    mlp=mlps[i],
+                    norm_layer=norm_layers[i],
+                    drop_path=dp_rates[cur + j],
+                    layer_scale_init_value=layer_scale_init_values[i],
+                    res_scale_init_value=res_scale_init_values[i]
+                ) for j in range(depths[i])]))
+            ])
             )
-            stages.append(stage)
             cur += depths[i]
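A minimal, self-contained sketch (not part of the patch) of the pattern the new stage construction relies on: `nn.Sequential` accepts an `OrderedDict` of named submodules, so each stage exposes its downsample layer and block stack by attribute (`stage.downsample`, `stage.blocks`) while still running them in order in `forward`. The modules below are placeholder stand-ins, not the real `downsample_layers[i]` or `MetaFormerBlock`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Placeholder stand-ins for downsample_layers[i] and the MetaFormerBlock stack.
downsample = nn.Conv2d(3, 8, kernel_size=2, stride=2)
blocks = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.GELU(),
)

# Named submodules, executed in insertion order by nn.Sequential.
stage = nn.Sequential(OrderedDict([
    ('downsample', downsample),
    ('blocks', blocks),
]))

x = torch.randn(1, 3, 32, 32)
print(stage(x).shape)    # torch.Size([1, 8, 16, 16])
print(stage.downsample)  # submodules are addressable by name
```

Pulling the downsample out of `MetaFormerBlock` this way keeps the block a pure token-mixer/MLP unit and gives the stage's downsample layer a stable name in the module hierarchy, which also makes feature extraction and checkpoint key mapping more predictable.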