|
|
|
@ -457,8 +457,9 @@ class DaViT(nn.Module):
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
self._update_forward_fn()
|
|
|
|
|
|
|
|
|
|
self.forward = self._get_forward_fn()
|
|
|
|
|
#self.forward = self._get_forward_fn()
|
|
|
|
|
'''
|
|
|
|
|
if self._features_only == True:
|
|
|
|
|
self.forward = self.forward_features_full
|
|
|
|
@ -495,7 +496,8 @@ class DaViT(nn.Module):
|
|
|
|
|
@features_only.setter
|
|
|
|
|
def features_only(self, new_value : bool):
|
|
|
|
|
self._features_only = new_value
|
|
|
|
|
self.forward = self._get_forward_fn()
|
|
|
|
|
#self.forward = self._get_forward_fn()
|
|
|
|
|
self._update_forward_fn()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|