diff --git a/timm/models/davit.py b/timm/models/davit.py index 4ab7b6b0..8de6d0dc 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -216,9 +216,9 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int): x: (B, H, W, C) """ - B : float = (windows.size(dim=0) / (H * W / window_size / window_size)) - x = windows.view(int(B), H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(int(B), H, W, -1) + B = torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int() + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x