diff --git a/timm/models/davit.py b/timm/models/davit.py index a2c39a49..c006fbce 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -309,7 +309,6 @@ class SpatialBlock(nn.Module): H, W = size B, L, C = x.shape - assert L == H * W, "input feature has wrong size" shortcut = self.cpe[0](x, size) x = self.norm1(shortcut) @@ -334,8 +333,8 @@ class SpatialBlock(nn.Module): C) x = window_reverse(attn_windows, self.window_size, Hp, Wp) - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() + #if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) x = shortcut + self.drop_path(x)