Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 7861c9dbf7
commit 4ecf422cd3

@ -576,8 +576,8 @@ class DaViT(nn.Module):
def forward_features(self, x):
x, sizes = self.forward_network(x)
# take final feature and norm
x = self.norms(x)
H, W = size
x = self.norms(x[-1])
H, W = sizes[-1]
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x

Loading…
Cancel
Save