diff --git a/timm/models/davit.py b/timm/models/davit.py index fd13161b..a404a4b4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -478,7 +478,7 @@ class DaViT(nn.Module): return self._features_only @features_only.setter - def features_only(self, new_value): + def features_only(self, new_value : bool): self._features_only = new_value self.forward = self._get_forward_fn()