Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent dbf38cd45b
commit 0de8313183

@ -457,8 +457,9 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights) self.apply(self._init_weights)
self._update_forward_fn()
self.forward = self._get_forward_fn() #self.forward = self._get_forward_fn()
''' '''
if self._features_only == True: if self._features_only == True:
self.forward = self.forward_features_full self.forward = self.forward_features_full
@ -495,7 +496,8 @@ class DaViT(nn.Module):
@features_only.setter @features_only.setter
def features_only(self, new_value : bool): def features_only(self, new_value : bool):
self._features_only = new_value self._features_only = new_value
self.forward = self._get_forward_fn() #self.forward = self._get_forward_fn()
self._update_forward_fn()

Loading…
Cancel
Save