@@ -86,9 +86,11 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
 
         if window_size:
@@ -127,13 +129,7 @@ class Attention(nn.Module):
     def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_bias = None
-        if self.q_bias is not None:
-            if torch.jit.is_scripting():
-                # FIXME requires_grad breaks w/ torchscript
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-            else:
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
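For reference (not part of the diff): a minimal sketch, assuming an arbitrary all_head_dim of 64 and stand-in tensors rather than the module's real parameters, illustrating why concatenating the zero, non-persistent k_bias buffer matches the old zeros_like construction that needed the torch.jit.is_scripting() branch.

import torch

all_head_dim = 64  # arbitrary example value, not taken from the diff
q_bias = torch.randn(all_head_dim)  # stands in for a trained q_bias parameter
v_bias = torch.randn(all_head_dim)  # stands in for a trained v_bias parameter
k_bias = torch.zeros(all_head_dim)  # stands in for the registered non-persistent zero buffer

old_qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias), v_bias))  # old branching path
new_qkv_bias = torch.cat((q_bias, k_bias, v_bias))                    # new single-expression path
assert torch.equal(old_qkv_bias, new_qkv_bias)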