diff --git a/timm/models/davit.py b/timm/models/davit.py
index 0e2ea4cc..9d2cf85f 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -399,7 +399,7 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
-        self.features_only = False
+        self._features_only = False
 
         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
@@ -458,12 +458,32 @@ class DaViT(nn.Module):
 
         self.apply(self._init_weights)
 
-        if self.features_only == True:
-            self.forward = self.forward_features_full
-        else:
-            self.forward = self.forward_classification
-
+        # Bind the forward implementation matching the initial mode.
+        self.forward = self._get_forward_fn()
+
+    def _get_forward_fn(self):
+        # Pick the bound forward method for the current mode.
+        if self._features_only:
+            return self.forward_features_full
+        return self.forward_classification
+
+    @property
+    def features_only(self):
+        return self._features_only
+
+    @features_only.setter
+    def features_only(self, new_value):
+        self._features_only = new_value
+        # Rebind forward so the new mode takes effect on the next call.
+        self.forward = self._get_forward_fn()
+
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
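For reference, a minimal, self-contained sketch of the pattern this patch adopts: a `features_only` property whose setter rebinds the instance's `forward` method. `TinyModel`, its layer sizes, and the stand-in methods `forward_features` / `forward_head` are illustrative assumptions, not part of timm.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Stand-in model exercising the forward-rebinding pattern from the patch."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)
        self._features_only = False
        # Bind the initial forward implementation, as the patch does.
        self.forward = self._get_forward_fn()

    def _get_forward_fn(self):
        # Return the bound method matching the current mode.
        if self._features_only:
            return self.forward_features
        return self.forward_head

    @property
    def features_only(self):
        return self._features_only

    @features_only.setter
    def features_only(self, new_value):
        self._features_only = new_value
        # Rebind forward so the mode change takes effect immediately.
        self.forward = self._get_forward_fn()

    def forward_features(self, x):
        # Stand-in for forward_features_full: return features unchanged.
        return x

    def forward_head(self, x):
        # Stand-in for forward_classification: run the classifier head.
        return self.proj(x)


model = TinyModel()
x = torch.randn(1, 4)
print(model(x).shape)        # torch.Size([1, 2]) -- classification path
model.features_only = True
print(model(x).shape)        # torch.Size([1, 4]) -- features-only path
```

Assigning to `self.forward` shadows the class-level `forward` for that instance only; `nn.Module.__call__` looks the attribute up on the instance first, so the rebound method is what actually runs, and the property setter still fires because plain values fall through `nn.Module.__setattr__` to `object.__setattr__`, which honors data descriptors.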