|
|
|
@ -64,6 +64,7 @@ class GPSA(nn.Module):
|
|
|
|
|
self.dim = dim
|
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
|
self.scale = qk_scale or head_dim ** -0.5
|
|
|
|
|
self.locality_strength = locality_strength
|
|
|
|
|
|
|
|
|
|
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
|
|
|
|
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
|
|
|
@ -72,7 +73,6 @@ class GPSA(nn.Module):
|
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
|
self.pos_proj = nn.Linear(3, num_heads)
|
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
self.locality_strength = locality_strength
|
|
|
|
|
self.gating_param = nn.Parameter(torch.ones(self.num_heads))
|
|
|
|
|
self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
|
|
|
|
|
|
|
|
|
|