Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent f6993dbf20
commit cd82d53149

@ -474,21 +474,7 @@ class DaViT(nn.Module):
return self.forward_classification
'''
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
@ -612,6 +598,22 @@ class DaViT(nn.Module):
def forward(self, x):
return x
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
def checkpoint_filter_fn(state_dict, model):

Loading…
Cancel
Save