|
|
@ -160,9 +160,15 @@ class WindowAttention(nn.Module):
|
|
|
|
# get relative_coords_table
|
|
|
|
# get relative_coords_table
|
|
|
|
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
|
|
|
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
|
|
|
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
|
|
|
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
|
|
|
relative_coords_table = torch.stack(torch.meshgrid([
|
|
|
|
relative_coords_table = torch.stack(
|
|
|
|
|
|
|
|
torch.meshgrid(
|
|
|
|
|
|
|
|
[
|
|
|
|
relative_coords_h,
|
|
|
|
relative_coords_h,
|
|
|
|
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
|
|
|
relative_coords_w
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
indexing='ij',
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
|
|
|
if pretrained_window_size[0] > 0:
|
|
|
|
if pretrained_window_size[0] > 0:
|
|
|
|
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
|
|
|
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
|
|
|
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
|
|
|
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
|
|
@ -178,7 +184,7 @@ class WindowAttention(nn.Module):
|
|
|
|
# get pair-wise relative position index for each token inside the window
|
|
|
|
# get pair-wise relative position index for each token inside the window
|
|
|
|
coords_h = torch.arange(self.window_size[0])
|
|
|
|
coords_h = torch.arange(self.window_size[0])
|
|
|
|
coords_w = torch.arange(self.window_size[1])
|
|
|
|
coords_w = torch.arange(self.window_size[1])
|
|
|
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
|
|
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
|
|
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
|
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
|
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
|
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
|
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
|
|
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
|
|
|