|
|
|
@ -474,7 +474,21 @@ class DaViT(nn.Module):
|
|
|
|
|
return self.forward_classification
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def _get_forward_fn(self):
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
return self.forward_features_full
|
|
|
|
|
else:
|
|
|
|
|
return self.forward_classification
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def features_only(self):
|
|
|
|
|
return self._features_only
|
|
|
|
|
|
|
|
|
|
@features_only.setter
|
|
|
|
|
def features_only(self, new_value : bool):
|
|
|
|
|
self._features_only = new_value
|
|
|
|
|
self.forward = self._get_forward_fn()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -577,7 +591,7 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
|
return features, sizes
|
|
|
|
|
return features[1:], sizes[:-1]
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x, sizes = self.forward_features_full(x)
|
|
|
|
@ -599,22 +613,6 @@ class DaViT(nn.Module):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def _get_forward_fn(self):
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
return self.forward_features_full
|
|
|
|
|
else:
|
|
|
|
|
return self.forward_classification
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def features_only(self):
|
|
|
|
|
return self._features_only
|
|
|
|
|
|
|
|
|
|
@features_only.setter
|
|
|
|
|
def features_only(self, new_value : bool):
|
|
|
|
|
self._features_only = new_value
|
|
|
|
|
self.forward = self._get_forward_fn()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
|