|
|
@ -577,9 +577,9 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
if self.features_only == True:
|
|
|
|
if self.features_only == True:
|
|
|
|
return forward_features_full(x)
|
|
|
|
return self.forward_features_full(x)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return forward(x)
|
|
|
|
return self.forward_classification(x)
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|