|
|
@ -473,13 +473,20 @@ class DaViT(nn.Module):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return self.forward_classification
|
|
|
|
return self.forward_classification
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
|
|
|
|
'''
|
|
|
|
@torch.jit.ignore
|
|
|
|
@torch.jit.ignore
|
|
|
|
def _get_forward_fn(self):
|
|
|
|
def _get_forward_fn(self):
|
|
|
|
if self._features_only == True:
|
|
|
|
if self._features_only == True:
|
|
|
|
return self.forward_features_full
|
|
|
|
return self.forward_features_full
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return self.forward_classification
|
|
|
|
return self.forward_classification
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_forward_fn(self):
|
|
|
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
|
|
|
self.forward = self.forward_features_full
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.forward = self.forward_classification
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def features_only(self):
|
|
|
|
def features_only(self):
|
|
|
|