From 3b993a9301f82e49f50fa226c20fe7ffabbc50ba Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 00:37:14 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 9fb8fdea..b9bcc8b9 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -383,6 +383,7 @@ class DaViT(nn.Module): img_size=224, num_classes=1000, global_pool='avg' + features_only = False ): super().__init__() @@ -398,7 +399,7 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - + self.features_only = False self.patch_embeds = nn.ModuleList([ PatchEmbed(patch_size=patch_size if i == 0 else 2, @@ -567,13 +568,18 @@ class DaViT(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) - def forward(self, x): + def forward_classification(self, x): x = self.forward_features(x) x = self.forward_head(x) return x + + def forward(self, x): + if self.features_only == True: + return forward_features_full(x) + else: + return forward(x) def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """