diff --git a/timm/models/davit.py b/timm/models/davit.py index e32aedf8..93ef42c9 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -519,8 +519,7 @@ class DaViT(nn.Module): return x def forward(self, x): - x = self.forward_classifier(self, x) - return x + return self.forward_classifier(x) class DaViTFeatures(DaViT): @@ -530,8 +529,7 @@ class DaViTFeatures(DaViT): self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4))) def forward(self, x) -> List[Tensor]: - x = self.forward_pyramid_features(self, x) - return x + return self.forward_pyramid_features(x) def checkpoint_filter_fn(state_dict, model):