Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent f6993dbf20
commit cd82d53149

@ -474,21 +474,7 @@ class DaViT(nn.Module):
return self.forward_classification return self.forward_classification
''' '''
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
@ -613,6 +599,22 @@ class DaViT(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """ """ Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save