Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8f32133dbb
commit 007be8319b

@ -577,9 +577,9 @@ class DaViT(nn.Module):
def forward(self, x):
if self.features_only == True:
return forward_features_full(x)
return self.forward_features_full(x)
else:
return forward(x)
return self.forward_classification(x)
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save