Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8f32133dbb
commit 007be8319b

@ -577,9 +577,9 @@ class DaViT(nn.Module):
def forward(self, x): def forward(self, x):
if self.features_only == True: if self.features_only == True:
return forward_features_full(x) return self.forward_features_full(x)
else: else:
return forward(x) return self.forward_classification(x)
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """ """ Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save