Fix some TPU (XLA) issues with swin transformer v2

pull/1414/head
Ross Wightman 2 years ago
parent ab16a358bb
commit ef57561d51

@ -392,13 +392,15 @@ class SwinTransformerBlock(nn.Module):
x = x.view(B, H, W, C)
# cyclic shift
sh, sw = self.shift_size
if any(self.shift_size):
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
else:
shifted_x = x
# FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C
x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
@ -406,13 +408,14 @@ class SwinTransformerBlock(nn.Module):
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
# reverse cyclic shift
if any(self.shift_size):
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
# FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
x = x.view(B, L, C)
return x
@ -452,8 +455,10 @@ class PatchMerging(nn.Module):
Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl
B, C, H, W = x.shape
# unfold + BCHW -> BHWC together
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
x = self.norm(x)
x = bhwc_to_bchw(self.reduction(x))
return x

Loading…
Cancel
Save