diff --git a/timm/models/davit.py b/timm/models/davit.py index 97395df1..917fc738 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -457,8 +457,9 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) + self._update_forward_fn() - self.forward = self._get_forward_fn() + #self.forward = self._get_forward_fn() ''' if self._features_only == True: self.forward = self.forward_features_full @@ -495,7 +496,8 @@ class DaViT(nn.Module): @features_only.setter def features_only(self, new_value : bool): self._features_only = new_value - self.forward = self._get_forward_fn() + #self.forward = self._get_forward_fn() + self._update_forward_fn()