diff --git a/timm/models/davit.py b/timm/models/davit.py index 509fd93f..45d60a67 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -534,7 +534,7 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) def forward_features(self, x): - x = self.patch_embed(x) + x = self.stem(x) x = self.stages(x) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x