@@ -140,6 +140,11 @@ class Attention(nn.Module):
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        #Uncomment this line for Container-PAM
+        #self.static_a = nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l, 1 + seq_l))
+        #trunc_normal_(self.static_a)
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -151,6 +156,10 @@ class Attention(nn.Module):

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
+        #Uncomment this line for Container-PAM
+        #attn = attn + self.static_a
         attn = self.attn_drop(attn)

         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
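
For context, a minimal self-contained sketch of what the Attention module looks like once the Container-PAM lines above are uncommented. It assumes the surrounding DeiT/timm-style code that the diff is taken from: trunc_normal_ is imported from timm.models.layers, seq_l is the number of patch tokens (196 for a 224x224 input with 16x16 patches; the +1 accounts for the class token), and the class name AttentionPAM plus the default seq_l=196 are illustrative, not names from the repository.

    # Sketch only: names marked below as assumptions are not taken from the diff.
    import torch
    import torch.nn as nn
    from timm.models.layers import trunc_normal_  # assumed import, as in DeiT/timm


    class AttentionPAM(nn.Module):  # hypothetical name for the modified module
        def __init__(self, dim, num_heads=8, seq_l=196, qkv_bias=False,
                     qk_scale=None, attn_drop=0., proj_drop=0.):
            super().__init__()
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = qk_scale or head_dim ** -0.5

            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            # Container-PAM: learnable static affinity matrix, shared across the batch,
            # covering the class token plus seq_l patch tokens per head.
            self.static_a = nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l, 1 + seq_l))
            trunc_normal_(self.static_a)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)

        def forward(self, x):
            # x: (B, N, C) with N == 1 + seq_l so that static_a broadcasts over attn
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
            q, k, v = qkv.permute(2, 0, 3, 1, 4)

            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            # Container-PAM: add the static affinity to the dynamic attention map
            attn = attn + self.static_a
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x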