Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 7dede5d9ac
commit 7d268438f7

@ -519,7 +519,9 @@ class DaViT(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
return self.forward_classifier(self, x) x = self.forward_classifier(self, x)
return x
class DaViTFeatures(DaViT): class DaViTFeatures(DaViT):
@ -528,8 +530,8 @@ class DaViTFeatures(DaViT):
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4))) self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
def forward(self, x) -> List[Tensor]: def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(self, x) x = self.forward_pyramid_features(self, x)
return x
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):

Loading…
Cancel
Save