Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 15b168e305
commit dbf38cd45b

@ -473,13 +473,20 @@ class DaViT(nn.Module):
else: else:
return self.forward_classification return self.forward_classification
''' '''
'''
@torch.jit.ignore @torch.jit.ignore
def _get_forward_fn(self): def _get_forward_fn(self):
if self._features_only == True: if self._features_only == True:
return self.forward_features_full return self.forward_features_full
else: else:
return self.forward_classification return self.forward_classification
'''
def _update_forward_fn(self):
if self._features_only == True:
self.forward = self.forward_features_full
else:
self.forward = self.forward_classification
@property @property
def features_only(self): def features_only(self):

Loading…
Cancel
Save