diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 656f6294..99b9b53f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -601,10 +601,11 @@ class MetaFormer(nn.Module): res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) + stages.append(downsample_layers[i]) stages.append(stage) cur += depths[i] - self.stages = nn.Sequential(zip(*downsample_layers, *stages)) + self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1])