|
|
|
@ -64,7 +64,6 @@ class GPSA(nn.Module):
|
|
|
|
|
self.dim = dim
|
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
|
self.scale = qk_scale or head_dim ** -0.5
|
|
|
|
|
self.locality_strength = locality_strength
|
|
|
|
|
|
|
|
|
|
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
|
|
|
|
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
|
|
|
|