update: set the `indexing` parameter to `ij` in `torch.meshgrid` calls

pull/1565/head
kozistr 3 years ago
parent 25ffac6880
commit c4f9bd4fc4

@ -133,7 +133,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
def get_relative_position_index(win_h, win_w): def get_relative_position_index(win_h, win_w):
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

@ -160,9 +160,15 @@ class WindowAttention(nn.Module):
# get relative_coords_table # get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([ relative_coords_table = torch.stack(
torch.meshgrid(
[
relative_coords_h, relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 relative_coords_w
],
indexing='ij',
)
).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0: if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
@ -178,7 +184,7 @@ class WindowAttention(nn.Module):
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

@ -194,9 +194,16 @@ class WindowMultiHeadAttention(nn.Module):
def _make_pair_wise_relative_positions(self) -> None: def _make_pair_wise_relative_positions(self) -> None:
"""Method initializes the pair-wise relative positions to compute the positional biases.""" """Method initializes the pair-wise relative positions to compute the positional biases."""
device = self.logit_scale.device device = self.logit_scale.device
coordinates = torch.stack(torch.meshgrid([ coordinates = torch.stack(
torch.meshgrid(
[
torch.arange(self.window_size[0], device=device), torch.arange(self.window_size[0], device=device),
torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) torch.arange(self.window_size[1], device=device),
],
indexing='ij',
),
dim=0,
).flatten(1)
relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
relative_coordinates_log = torch.sign(relative_coordinates) * torch.log( relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(

Loading…
Cancel
Save