From e7c5ab9d1e487740dafa3d90eee5c8e4b70f1a7c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 01:57:46 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index e03d25a6..d7e29fe5 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -485,7 +485,7 @@ class DaViT(nn.Module): @torch.jit.ignore def _update_forward_fn(self): if self._features_only == True: - self.forward = self.forward_features_full + self.forward = self.forward_pyramid_features else: self.forward = self.forward_classification @@ -527,7 +527,7 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - def forward_features_full(self, x): + def forward_network(self, x): #x, size = self.patch_embeds[0](x, (x.size(2), x.size(3))) size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] @@ -601,9 +601,18 @@ class DaViT(nn.Module): # non-normalized pyramid features + corresponding sizes return features[1:], sizes[:-1] + + def forward_pyramid_features(self, x): + x, sizes = self.forward_network(x) + outs = [] + for i, out in enumerate(x): + H, W = sizes[i] + outs.append(x.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) + return outs + def forward_features(self, x): - x, sizes = self.forward_features_full(x) + x, sizes = self.forward_network(x) # take final feature and norm x = self.norms(x[-1]) H, W = sizes[-1]