From c4f9bd4fc41e3f32c33a0142cf03cc8f282a77d6 Mon Sep 17 00:00:00 2001 From: kozistr Date: Thu, 24 Nov 2022 16:21:46 +0900 Subject: [PATCH] update: set `indexing` parameter with `ij` --- timm/models/swin_transformer.py | 2 +- timm/models/swin_transformer_v2.py | 14 ++++++++++---- timm/models/swin_transformer_v2_cr.py | 13 ++++++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f2305fb2..be83be96 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -133,7 +133,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): def get_relative_position_index(win_h, win_w): # get pair-wise relative position index for each token inside the window - coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing='ij')) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 0c9db3dd..d2adf01c 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -160,9 +160,15 @@ class WindowAttention(nn.Module): # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack(torch.meshgrid([ - relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_table = torch.stack( + torch.meshgrid( + [ + relative_coords_h, + relative_coords_w + ], + indexing='ij', + ) + ).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) @@ -178,7 +184,7 @@ class WindowAttention(nn.Module): # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d143c14c..8466f3f7 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -194,9 +194,16 @@ class WindowMultiHeadAttention(nn.Module): def _make_pair_wise_relative_positions(self) -> None: """Method initializes the pair-wise relative positions to compute the positional biases.""" device = self.logit_scale.device - coordinates = torch.stack(torch.meshgrid([ - torch.arange(self.window_size[0], device=device), - torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) + coordinates = torch.stack( + torch.meshgrid( + [ + torch.arange(self.window_size[0], device=device), + torch.arange(self.window_size[1], device=device), + ], + indexing='ij', + ), + dim=0, + ).flatten(1) relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(