diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 2e6a59af..e6c12d7f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -721,7 +721,7 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = x.mean([1,2]) # TODO use adaptive pool instead of mean + x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x)