diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d5bf9692..1356c769 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -814,7 +814,7 @@ class MetaFormer(nn.Module): #x = self.norm(x) # (B, H, W, C) -> (B, C) #x = self.head(x) - x=self.head(self.norm(x.mean([1, 2]))) + x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x):