From 7dede5d9ac0fc9396de5ef68c8f4074f56f5e09a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 07:28:32 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index e3d9d719..ff69dc89 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -493,6 +493,14 @@ class DaViT(nn.Module): # non-normalized pyramid features + corresponding sizes return features, sizes + def forward_pyramid_features(self, x) -> List[Tensor]: + x, sizes = self.forward_network(x) + outs = [] + for i, out in enumerate(x): + H, W = sizes[i] + outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) + + return outs def forward_features(self, x): x, sizes = self.forward_network(x) @@ -505,10 +513,13 @@ class DaViT(nn.Module): def forward_head(self, x, pre_logits: bool = False): return self.head(x, pre_logits=pre_logits) - def forward(self, x): + def forward_classifier(self, x): x = self.forward_features(x) x = self.forward_head(x) return x + + def forward(self, x): + return self.forward_classifier(self, x) class DaViTFeatures(DaViT): @@ -517,13 +528,8 @@ class DaViTFeatures(DaViT): self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4))) def forward(self, x) -> List[Tensor]: - x, sizes = self.forward_network(x) - outs = [] - for i, out in enumerate(x): - H, W = sizes[i] - outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) - - return outs + return self.forward_pyramid_features(self, x) + def checkpoint_filter_fn(state_dict, model):