From 4c8c7faa12f3a1e7724bfc3c2e29d98f42b627df Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 20:41:46 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index db73b903..edf6edbc 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -792,34 +792,36 @@ class DaViT(nn.Module): def forward_network(self, x : Tensor): size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] - sizes = [size] + #sizes = [size] for stage in self.stages: - features[-1], sizes[-1] = stage(features[-1], sizes[-1]) + features[-1] = stage(features[-1]) # don't append outputs of last stage, since they are already there if(len(features) < self.num_stages): features.append(features[-1]) - sizes.append(sizes[-1]) + # non-normalized pyramid features + corresponding sizes - return features, sizes + return features def forward_pyramid_features(self, x) -> List[Tensor]: - x, sizes = self.forward_network(x) + x = self.forward_network(x) + ''' outs = [] for i, out in enumerate(x): H, W = sizes[i] - outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) - return outs + outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) + ''' + return x def forward_features(self, x): - x, sizes = self.forward_network(x) + x = self.forward_network(x) # take final feature and norm x = self.norms(x[-1]) - H, W = sizes[-1] - x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() + #H, W = sizes[-1] + #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() return x def forward_head(self, x, pre_logits: bool = False):