From 7d268438f70b460f54e017b24882bfc561adcc1a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 07:29:54 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index ff69dc89..e32aedf8 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -519,8 +519,10 @@ class DaViT(nn.Module): return x def forward(self, x): - return self.forward_classifier(self, x) - + x = self.forward_classifier(self, x) + return x + + class DaViTFeatures(DaViT): def __init__(*args): @@ -528,8 +530,8 @@ class DaViTFeatures(DaViT): self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4))) def forward(self, x) -> List[Tensor]: - return self.forward_pyramid_features(self, x) - + x = self.forward_pyramid_features(self, x) + return x def checkpoint_filter_fn(state_dict, model):