diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index be03eec3..e0189720 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -527,7 +527,7 @@ class MetaFormerBlock(nn.Module): super().__init__() - self.downsample = nn.Identity() + self.downsample = downsample self.norm1 = norm_layer(dim) self.token_mixer = token_mixer(dim=dim, drop=drop)