@@ -392,13 +392,15 @@ class SwinTransformerBlock(nn.Module):
         x = x.view(B, H, W, C)
 
         # cyclic shift
-        if any(self.shift_size):
-            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
-        else:
-            shifted_x = x
+        sh, sw = self.shift_size
+        if any(self.shift_size):
+            # FIXME PyTorch XLA needs cat impl, roll not lowered
+            # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
+            # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
+            x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
 
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # num_windows * B, window_size, window_size, C
+        x_windows = window_partition(x, self.window_size)  # num_windows * B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
 
         # W-MSA/SW-MSA
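Here `roll` stays the active path, while the commented-out `torch.cat` lines record an XLA-friendly fallback, since PyTorch XLA does not lower `roll`. A minimal standalone check, not part of the diff, that the two formulations produce the same cyclic shift (batch size, feature size, and shift values below are illustrative):

```python
import torch

# Illustrative shapes: B, H, W, C and a (sh, sw) shift, e.g. window_size // 2.
x = torch.randn(2, 8, 8, 4)
sh, sw = 3, 3

# Active path in the diff: a single roll over the two spatial dims.
rolled = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))

# Commented-out fallback: move the first sh rows / sw cols to the end via cat.
shifted = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
shifted = torch.cat([shifted[:, :, sw:], shifted[:, :, :sw]], dim=2)

assert torch.equal(rolled, shifted)
```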
@@ -406,13 +408,14 @@ class SwinTransformerBlock(nn.Module):
 
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
-        shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
+        x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
 
         # reverse cyclic shift
-        if any(self.shift_size):
-            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
-        else:
-            x = shifted_x
+        if any(self.shift_size):
+            # FIXME PyTorch XLA needs cat impl, roll not lowered
+            # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
+            # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
+            x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
 
         x = x.view(B, L, C)
         return x
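The rename from `shifted_x` to plain `x` relies on `window_partition` / `window_reverse` being exact inverses, so one variable can flow through the whole attention path. A sketch of that round trip using hypothetical standalone versions of the two helpers, with their signatures assumed from the calls in the hunks above:

```python
import torch

def window_partition(x, window_size):
    # assumed layout: [B, H, W, C] -> [num_windows * B, wh, ww, C]
    B, H, W, C = x.shape
    wh, ww = window_size
    x = x.view(B, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, wh, ww, C)

def window_reverse(windows, window_size, img_size):
    # inverse of window_partition: [num_windows * B, wh, ww, C] -> [B, H, W, C]
    H, W = img_size
    wh, ww = window_size
    C = windows.shape[-1]
    x = windows.view(-1, H // wh, W // ww, wh, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, H, W, C)

x = torch.randn(2, 8, 8, 4)            # B, H, W, C
windows = window_partition(x, (4, 4))  # 2 * 2 windows per image
assert torch.equal(window_reverse(windows, (4, 4), (8, 8)), x)
```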
@@ -452,8 +455,10 @@ class PatchMerging(nn.Module):
         Returns:
             output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
         """
-        x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
-        x = x.permute(0, 1, 2, 5, 4, 3).flatten(3)  # permute maintains compat with ch order in official swin impl
+        B, C, H, W = x.shape
+        # unfold + BCHW -> BHWC together
+        # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
+        x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
         x = self.norm(x)
         x = bhwc_to_bchw(self.reduction(x))
         return x
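The rewritten merge folds the double `unfold` and the BCHW to BHWC conversion into one `reshape` + `permute`, with the `5, 3, 1` ordering preserving the channel order of the original Swin V1 merge. A quick equivalence check between the old and new paths (assuming `bchw_to_bhwc` is a plain permute, as its name suggests):

```python
import torch

def bchw_to_bhwc(x):
    # assumed helper: [B, C, H, W] -> [B, H, W, C]
    return x.permute(0, 2, 3, 1)

x = torch.randn(2, 4, 8, 8)  # B, C, H, W
B, C, H, W = x.shape

# old path: convert to BHWC, unfold 2x2 patches, reorder the patch dims
old = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
old = old.permute(0, 1, 2, 5, 4, 3).flatten(3)

# new path from the diff: one reshape in BCHW, then a single permute
new = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)

assert torch.equal(old, new)  # both yield [B, H // 2, W // 2, 4 * C]
```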