diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 92a8851f..82e44579 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -749,7 +749,7 @@ class MetaFormer(nn.Module): return x x = self.global_pool(x) - x = x.flatten() + x = x.flatten(1) x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x)