Improve tracing of window attn models with simpler reshape logic

pull/1553/merge
Ross Wightman 2 years ago committed by Ross Wightman
parent a3c6685e20
commit 7d9e321b76

@@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int
Returns: Returns:
x: (B, H, W, C) x: (B, H, W, C)
""" """
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) C = windows.shape[-1]
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x return x

@@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]):
@register_notrace_function # reason: int argument is a Proxy @register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
H, W = img_size H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) C = windows.shape[-1]
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x return x

@@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns: Returns:
x: (B, H, W, C) x: (B, H, W, C)
""" """
B = int(windows.shape[0] / (H * W / window_size / window_size)) C = windows.shape[-1]
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x return x

@@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C) x: (B, H, W, C)
""" """
H, W = img_size H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) C = windows.shape[-1]
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x return x

@@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C) x: (B, H, W, C)
""" """
H, W = img_size H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) C = windows.shape[-1]
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x return x

Loading…
Cancel
Save