diff --git a/timm/models/davit.py b/timm/models/davit.py index 2e3e33ba..1028c82f 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -508,11 +508,11 @@ class SpatialBlock(nn.Module): x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) - x = self.cpe2(x) + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) - x = x.transpose(1, 2).view(B, C, H, W) + x = x.transpose(1, 2).view(B, C, H, W) return x