From f6993dbf20a47587f321295d1febbfa59cd3af82 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 01:23:10 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index bf46917a..15674187 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -458,13 +458,13 @@ class DaViT(nn.Module): self.apply(self._init_weights) - #self.forward = self._get_forward_fn() - + self.forward = self._get_forward_fn() + ''' if self._features_only == True: self.forward = self.forward_features_full else: self.forward = self.forward_classification - + ''' ''' def _get_forward_fn(self): @@ -609,13 +609,10 @@ class DaViT(nn.Module): x = self.forward_features(x) x = self.forward_head(x) return x - ''' + def forward(self, x): - if self.features_only == True: - return self.forward_features_full(x) - else: - return self.forward_classification(x) - ''' + return x + def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """