@@ -81,7 +81,7 @@ class Attention(nn.Module):
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
         attn = attn.softmax(dim=-1)
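For reference, a minimal standalone sketch of the shapes flowing through this hunk. The toy sizes, the `nn.Linear` stand-in for `self.qkv`, and the `scale` value are assumptions reconstructed from the diff context, not the full class; the snippet only demonstrates that `unbind(0)` yields the same `q`, `k`, `v` as indexing.

import torch
import torch.nn as nn

# Hypothetical toy sizes: (B)atch, (T)ime, (N) tokens per frame, (C)hannels.
B, T, N, C = 2, 4, 16, 64
num_heads = 8
head_dim = C // num_heads
scale = head_dim ** -0.5

qkv_proj = nn.Linear(C, C * 3)  # stand-in for self.qkv
x = torch.randn(B, T, N, C)

qkv = qkv_proj(x).reshape(B, T, N, 3, num_heads, head_dim).permute(3, 0, 4, 1, 2, 5)
# qkv is now (3, B, num_heads, T, N, head_dim)

q, k, v = qkv.unbind(0)  # new form: one call instead of three indexing ops
assert torch.equal(q, qkv[0]) and torch.equal(k, qkv[1]) and torch.equal(v, qkv[2])

attn = (q @ k.transpose(-2, -1)) * scale  # (B, num_heads, T, N, N)
attn = attn.softmax(dim=-1)
print(attn.shape)  # torch.Size([2, 8, 4, 16, 16])

`unbind` returns a tuple of views along dim 0, so it sidesteps the limitation the diff's own comment names (TorchScript cannot treat a plain tensor as a tuple) while leaving the attention math unchanged.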