diff --git a/timm/models/davit.py b/timm/models/davit.py index e31b5e58..97395df1 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -473,13 +473,20 @@ class DaViT(nn.Module): else: return self.forward_classification ''' - + ''' @torch.jit.ignore def _get_forward_fn(self): if self._features_only == True: return self.forward_features_full else: return self.forward_classification + ''' + + def _update_forward_fn(self): + if self._features_only == True: + self.forward = self.forward_features_full + else: + self.forward = self.forward_classification @property def features_only(self):