diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index ab1b0f0c..257d4e74 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -568,7 +568,7 @@ class MetaFormer(nn.Module): if not isinstance(downsample_layers, (list, tuple)): downsample_layers = [downsample_layers] * num_stage down_dims = [in_chans] + dims - self.downsample_layers = nn.ModuleList( + downsample_layers = nn.ModuleList( [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] ) @@ -604,7 +604,7 @@ class MetaFormer(nn.Module): self.stages.append(stage) cur += depths[i] - self.stages = nn.Sequential(*stages) + self.stages = nn.Sequential(zip(*downsample_layers, *stages)) self.norm = output_norm(dims[-1])