From 7d9e321b761a673000af312ad21ef1dec491b1e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 16 Feb 2023 23:46:52 -0800 Subject: [PATCH] Improve tracing of window attn models with simpler reshape logic --- timm/models/davit.py | 6 +++--- timm/models/gcvit.py | 6 +++--- timm/models/swin_transformer.py | 6 +++--- timm/models/swin_transformer_v2.py | 6 +++--- timm/models/swin_transformer_v2_cr.py | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index e9871265..eb6492f9 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) - x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index ec9b7e5e..2423a954 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]): @register_notrace_function # reason: int argument is a Proxy def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): H, W = img_size - B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) - x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 5df06d4d..bbc97036 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = windows.shape[-1] + x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index efaaa9e9..ffdb85e0 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i x: (B, H, W, C) """ H, W = img_size - B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) - x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index cf10b39c..9185e3e7 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i x: (B, H, W, C) """ H, W = img_size - B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) - x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x