Merge pull request #933 from t-vi/unbind

use .unbind instead of explicitly listing the indices
more_datasets
Ross Wightman 3 years ago committed by GitHub
commit 7da1b0b61c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -136,7 +136,7 @@ class Attention(nn.Module):
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))

@ -81,7 +81,7 @@ class Attention(nn.Module):
B, T, N, C = x.shape
# result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
attn = attn.softmax(dim=-1)

@ -172,7 +172,7 @@ class WindowAttention(nn.Module):
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))

@ -61,7 +61,7 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale

@ -190,7 +190,7 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)

@ -267,7 +267,7 @@ class XCA(nn.Module):
B, N, C = x.shape
# Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# Paper section 3.2 l2-Normalization and temperature scaling
q = torch.nn.functional.normalize(q, dim=-1)

Loading…
Cancel
Save