From 889c2ca728015ef1735fa161c9d4e229f9334a72 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 00:55:26 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 0e2ea4cc..9d2cf85f 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -399,7 +399,7 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - self.features_only = False + self._features_only = False self.patch_embeds = nn.ModuleList([ PatchEmbed(patch_size=patch_size if i == 0 else 2, @@ -458,12 +458,32 @@ class DaViT(nn.Module): self.apply(self._init_weights) - - if self.features_only == True: + self.forward = _get_forward_fn() + ''' + if self._features_only == True: self.forward = self.forward_features_full else: self.forward = self.forward_classification - + ''' + + + def _get_forward_fn(self): + if self._features_only == True: + return self.forward_features_full + else: + return self.forward_classification + + @property + def features_only(self): + return self._features_only + + @value.setter + def features_only(self, new_value): + self._features_only = new_value + self.forward = _get_forward_fn() + + + def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02)