diff --git a/timm/models/davit.py b/timm/models/davit.py index 7a6d8142..fe2c1bf3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -582,7 +582,7 @@ class DaViT(nn.Module): # take final feature and norm x = self.norms(x) - H, W = sizes + H, W = size x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() return x