diff --git a/timm/models/davit.py b/timm/models/davit.py index bde025cb..bf46917a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -458,13 +458,13 @@ class DaViT(nn.Module): self.apply(self._init_weights) - self.forward = self._get_forward_fn() - ''' + #self.forward = self._get_forward_fn() + if self._features_only == True: self.forward = self.forward_features_full else: self.forward = self.forward_classification - ''' + ''' def _get_forward_fn(self):