diff --git a/timm/models/davit.py b/timm/models/davit.py
index e03d25a6..d7e29fe5 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -485,7 +485,7 @@ class DaViT(nn.Module):
     @torch.jit.ignore
     def _update_forward_fn(self):
         if self._features_only == True:
-            self.forward = self.forward_features_full
+            self.forward = self.forward_pyramid_features
         else:
             self.forward = self.forward_classification
 
@@ -527,7 +527,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-    def forward_features_full(self, x):
+    def forward_network(self, x):
         #x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
         size: Tuple[int, int] = (x.size(2), x.size(3))
         features = [x]
@@ -601,9 +601,18 @@ class DaViT(nn.Module):
         # non-normalized pyramid features + corresponding sizes
         return features[1:], sizes[:-1]
+
+    def forward_pyramid_features(self, x):
+        x, sizes = self.forward_network(x)
+        outs = []
+        for i, out in enumerate(x):
+            H, W = sizes[i]
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+        return outs
+
     def forward_features(self, x):
-        x, sizes = self.forward_features_full(x)
+        x, sizes = self.forward_network(x)
         # take final feature and norm
         x = self.norms(x[-1])
         H, W = sizes[-1]
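
Usage sketch (not part of the patch): a minimal illustration of how the renamed entry points route, assuming DaViT() is constructible with its default hyper-parameters and that _features_only / _update_forward_fn behave as shown in the first hunk. Everything not named in the diff above is an assumption for illustration.

    import torch
    from timm.models.davit import DaViT

    # Assumed: DaViT() builds a valid model from its default arguments.
    model = DaViT().eval()

    # Flip the private flag and re-route forward(), as _update_forward_fn
    # does in the first hunk; forward now points at forward_pyramid_features.
    model._features_only = True
    model._update_forward_fn()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))

    # forward_pyramid_features returns one NCHW feature map per stage,
    # reshaped from (B, H*W, C) token sequences via view + permute.
    for f in feats:
        print(f.shape)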