diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1a074944..92a8851f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -749,6 +749,7 @@ class MetaFormer(nn.Module): return x x = self.global_pool(x) + x = x.flatten() x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x)