diff --git a/timm/models/davit.py b/timm/models/davit.py index c006fbce..48b387c4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -55,7 +55,6 @@ class ConvPosEnc(nn.Module): def forward(self, x : Tensor, size: Tuple[int, int]): B, N, C = x.shape H, W = size - assert N == H * W feat = x.transpose(1, 2).view(B, C, H, W) feat = self.proj(feat)