diff --git a/timm/models/davit.py b/timm/models/davit.py index f0fc52da..fd13161b 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -458,7 +458,7 @@ class DaViT(nn.Module): self.apply(self._init_weights) - self.forward = _get_forward_fn() + self.forward = self._get_forward_fn() ''' if self._features_only == True: self.forward = self.forward_features_full @@ -480,7 +480,7 @@ class DaViT(nn.Module): @features_only.setter def features_only(self, new_value): self._features_only = new_value - self.forward = _get_forward_fn() + self.forward = self._get_forward_fn()