Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent b80fb5652c
commit 34529ed363

@ -215,8 +215,8 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
Returns: Returns:
x: (B, H, W, C) x: (B, H, W, C)
""" """
window_size_dim_0 : int = windows.size(dim=0)
B = int(window_size_dim_0 / (H * W / window_size / window_size)) B = torch.floor(torch.Tensor(windows.size(dim=0)) / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x return x

Loading…
Cancel
Save