From ef57561d5124f831051e5996d8346e95ded69c14 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 16 Mar 2022 14:55:36 -0700 Subject: [PATCH] Fix some TPU (XLA) issues with swin transformer v2 --- timm/models/swin_transformer_v2_cr.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 39ea993e..d3bf8c85 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -392,13 +392,15 @@ class SwinTransformerBlock(nn.Module): x = x.view(B, H, W, C) # cyclic shift + sh, sw = self.shift_size if any(self.shift_size): - shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) - else: - shifted_x = x + # FIXME PyTorch XLA needs cat impl, roll not lowered + # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1) + # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2) + x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2)) # partition windows - x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C + x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) # W-MSA/SW-MSA @@ -406,13 +408,14 @@ class SwinTransformerBlock(nn.Module): # merge windows attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) - shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C + x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C # reverse cyclic shift if any(self.shift_size): - x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) - else: - x = shifted_x + # FIXME PyTorch XLA needs cat impl, roll not lowered + # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1) + # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2) + x = torch.roll(x, shifts=(sh, sw), dims=(1, 2)) x = x.view(B, L, C) return x @@ -452,8 +455,10 @@ class PatchMerging(nn.Module): Returns: output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] """ - x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) - x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl + B, C, H, W = x.shape + # unfold + BCHW -> BHWC together + # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge + x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3) x = self.norm(x) x = bhwc_to_bchw(self.reduction(x)) return x