|
|
|
@ -399,7 +399,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.num_features = embed_dims[-1]
|
|
|
|
|
self.drop_rate=drop_rate
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
self.features_only = False
|
|
|
|
|
self._features_only = False
|
|
|
|
|
|
|
|
|
|
self.patch_embeds = nn.ModuleList([
|
|
|
|
|
PatchEmbed(patch_size=patch_size if i == 0 else 2,
|
|
|
|
@ -458,11 +458,31 @@ class DaViT(nn.Module):
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.features_only == True:
|
|
|
|
|
self.forward = _get_forward_fn()
|
|
|
|
|
'''
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
self.forward = self.forward_features_full
|
|
|
|
|
else:
|
|
|
|
|
self.forward = self.forward_classification
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_forward_fn(self):
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
return self.forward_features_full
|
|
|
|
|
else:
|
|
|
|
|
return self.forward_classification
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def features_only(self):
|
|
|
|
|
return self._features_only
|
|
|
|
|
|
|
|
|
|
@value.setter
|
|
|
|
|
def features_only(self, new_value):
|
|
|
|
|
self._features_only = new_value
|
|
|
|
|
self.forward = _get_forward_fn()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|