diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index dc83c185..87018b47 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -809,12 +809,12 @@ class MetaFormer(nn.Module): if pre_logits: return x - #x = self.global_pool(x) - #x = x.squeeze() - #x = self.norm(x) + x = self.global_pool(x) + x = x.squeeze() + x = self.norm(x) # (B, H, W, C) -> (B, C) - #x = self.head(x) - x=self.head(self.norm(x.mean([2, 3]))) + x = self.head(x) + #x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x):