From 27c42f0830afab4b2ff40b948cf612328ed26680 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 13 May 2022 09:29:33 -0700
Subject: [PATCH] Fix torchscript use for official Swin-V2, add support for
 non-square window/shift to WindowAttn/Block

---
 timm/models/swin_transformer_v2.py | 80 ++++++++++++++++--------------
 1 file changed, 43 insertions(+), 37 deletions(-)

diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 700012fe..8b4eff64 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -13,6 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
+from typing import Tuple, Optional
 
 import torch
 import torch.nn as nn
@@ -91,7 +92,7 @@ default_cfgs = {
 }
 
 
-def window_partition(x, window_size):
+def window_partition(x, window_size: Tuple[int, int]):
     """
     Args:
         x: (B, H, W, C)
@@ -101,25 +102,25 @@
         windows: (num_windows*B, window_size, window_size, C)
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size, H, W):
+def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
     """
     Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        window_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
+        windows: (num_windows * B, window_size[0], window_size[1], C)
+        window_size (Tuple[int, int]): Window size
+        img_size (Tuple[int, int]): Image size
 
     Returns:
         x: (B, H, W, C)
     """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    H, W = img_size
+    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x
 
@@ -148,7 +149,7 @@ class WindowAttention(nn.Module):
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
 
-        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
 
         # mlp to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
@@ -202,7 +203,7 @@
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask=None):
+    def forward(self, x, mask: Optional[torch.Tensor] = None):
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
@@ -218,7 +219,7 @@
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
@@ -269,16 +270,13 @@
                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
         super().__init__()
         self.dim = dim
-        self.input_resolution = input_resolution
+        self.input_resolution = to_2tuple(input_resolution)
         self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        _assert(0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size")
 
         self.attn = WindowAttention(
             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
@@ -291,23 +289,23 @@
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
@@ -315,6 +313,13 @@
             attn_mask = None
         self.register_buffer("attn_mask", attn_mask)
 
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
     def _attn(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
@@ -322,25 +327,26 @@
         x = x.view(B, H, W, C)
 
         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
             shifted_x = x
 
         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution)  # B H' W' C
 
         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
             x = shifted_x
         x = x.view(B, H * W, C)
@@ -445,7 +451,7 @@
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.grad_checkpointing:
+            if not torch.jit.is_scripting() and self.grad_checkpointing:
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
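
Not part of the patch, for review convenience: a standalone sanity check of the tuple-based helpers introduced above. The two function bodies are copied from the diff (the timm-only @register_notrace_function decorator is dropped so the snippet runs on its own), and calc_window_shift restates the new SwinTransformerBlock._calc_window_shift as a free function. The 8x16 resolution, the (4, 8) window, and all tensor shapes are illustrative values chosen here, not taken from the patch.

    from typing import Tuple

    import torch


    def window_partition(x, window_size: Tuple[int, int]):
        # Split (B, H, W, C) feature maps into (num_windows*B, wh, ww, C) windows.
        B, H, W, C = x.shape
        x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)


    def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
        # Merge (num_windows*B, wh, ww, C) windows back into (B, H, W, C).
        H, W = img_size
        B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
        x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


    def calc_window_shift(input_resolution, target_window_size, target_shift_size):
        # Per-axis clamp as in the new _calc_window_shift: when an input axis is no
        # larger than the requested window, cover the full axis and disable its shift.
        window_size = tuple(r if r <= w else w for r, w in zip(input_resolution, target_window_size))
        shift_size = tuple(0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size))
        return window_size, shift_size


    # Non-square 4x8 windows on an 8x16 feature map: partition/reverse round-trips exactly.
    x = torch.randn(2, 8, 16, 32)                  # B, H, W, C
    windows = window_partition(x, (4, 8))
    assert windows.shape == (2 * 2 * 2, 4, 8, 32)  # (8 // 4) * (16 // 8) windows per image
    assert torch.equal(window_reverse(windows, (4, 8), (8, 16)), x)

    # Height axis is clamped to the full 8 rows with its shift disabled; the width
    # axis keeps the requested window and shift.
    assert calc_window_shift((8, 16), (16, 8), (8, 4)) == ((8, 8), (0, 4))

The per-axis clamp is what lets a block use a full-height window with no vertical shift while still partitioning and cyclically shifting along the width, which the old scalar window_size/shift_size could not express.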