diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7a90b448..1cb13141 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -573,7 +573,7 @@ class MetaFormer(nn.Module): return x if pre_logits else self.head.fc(x) def forward_features(self, x): - x = self.patch_embed(x) + x = self.stem(x) x = self.stages(x) x = self.norm_pre(x) return x