diff --git a/timm/models/beit.py b/timm/models/beit.py
index e8d1dd2c..199c2a4b 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -136,7 +136,7 @@ class Attention(nn.Module):
             qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
diff --git a/timm/models/nest.py b/timm/models/nest.py
index fe0645cc..9a477bf9 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -81,7 +81,7 @@ class Attention(nn.Module):
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
         attn = attn.softmax(dim=-1)
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 2ee106d2..822aeef8 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -172,7 +172,7 @@ class WindowAttention(nn.Module):
         """
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
@@ -649,4 +649,4 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
     """
     model_kwargs = dict(
         patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
\ No newline at end of file
+    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index 8186cc4a..9829653c 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -61,7 +61,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk[0], qk[1]  # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

         attn = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ca8f52de..94ae2666 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -190,7 +190,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
@@ -893,4 +893,4 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     """
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
-    return model
\ No newline at end of file
+    return model
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index b7af3b26..2942ed8a 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -267,7 +267,7 @@ class XCA(nn.Module):
         B, N, C = x.shape
         # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         # Paper section 3.2 l2-Normalization and temperature scaling
         q = torch.nn.functional.normalize(q, dim=-1)
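
A minimal sanity-check sketch (not part of the patch): torch.Tensor.unbind(dim) splits the stacked qkv tensor into the same per-slice views that integer indexing produced, so the q/k/v values are unchanged across every file touched above. The tensor shape below is a placeholder chosen only for the example.

    import torch

    # Stand-in for the projected-and-permuted tensor: (3, B, num_heads, N, head_dim)
    qkv = torch.randn(3, 2, 4, 16, 8)

    q_idx, k_idx, v_idx = qkv[0], qkv[1], qkv[2]  # old style: integer indexing
    q_ub, k_ub, v_ub = qkv.unbind(0)              # new style: single unbind call

    # The two unpacking styles yield identical tensors.
    assert torch.equal(q_idx, q_ub)
    assert torch.equal(k_idx, k_ub)
    assert torch.equal(v_idx, v_ub)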