Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent cd82d53149
commit 15b168e305

@ -474,7 +474,21 @@ class DaViT(nn.Module):
return self.forward_classification return self.forward_classification
''' '''
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
@ -577,7 +591,7 @@ class DaViT(nn.Module):
# non-normalized pyramid features + corresponding sizes # non-normalized pyramid features + corresponding sizes
return features, sizes return features[1:], sizes[:-1]
def forward_features(self, x): def forward_features(self, x):
x, sizes = self.forward_features_full(x) x, sizes = self.forward_features_full(x)
@ -598,22 +612,6 @@ class DaViT(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
@torch.jit.ignore
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
@property
def features_only(self):
return self._features_only
@features_only.setter
def features_only(self, new_value : bool):
self._features_only = new_value
self.forward = self._get_forward_fn()
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):

Loading…
Cancel
Save