Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent ce5cef9ce0
commit 8408551195

@ -819,7 +819,7 @@ class DaViT(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.stages(x) x = self.stages(x)
# take final feature and norm # take final feature and norm
x = self.norms(x[-1].permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#H, W = sizes[-1] #H, W = sizes[-1]
#x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x return x

Loading…
Cancel
Save