Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent d1e68fa33f
commit c42e26f866

@ -458,7 +458,7 @@ class DaViT(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
self.forward = self._get_forward_fn() self._update_forward_fn()
''' '''
if self._features_only == True: if self._features_only == True:
self.forward = self.forward_features_full self.forward = self.forward_features_full
@ -466,7 +466,15 @@ class DaViT(nn.Module):
self.forward = self.forward_classification self.forward = self.forward_classification
''' '''
'''
def _get_forward_fn(self):
if self._features_only == True:
return self.forward_features_full
else:
return self.forward_classification
'''
@torch.jit.ignore
def _get_forward_fn(self): def _get_forward_fn(self):
if self._features_only == True: if self._features_only == True:
return self.forward_features_full return self.forward_features_full

Loading…
Cancel
Save