Post merge cleanup

pull/693/head
Ross Wightman 3 years ago
parent 45dec179e5
commit 2a63d0246b

@ -64,6 +64,7 @@ class GPSA(nn.Module):
self.dim = dim
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.locality_strength = locality_strength
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
@ -72,7 +73,6 @@ class GPSA(nn.Module):
self.proj = nn.Linear(dim, dim)
self.pos_proj = nn.Linear(3, num_heads)
self.proj_drop = nn.Dropout(proj_drop)
self.locality_strength = locality_strength
self.gating_param = nn.Parameter(torch.ones(self.num_heads))
self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None

@ -92,8 +92,7 @@ class LinearBottleneck(nn.Module):
if self.use_shortcut:
if self.drop_path is not None:
x = self.drop_path(x)
x[:, 0:self.in_channels] += shortcut
x[:, 0:self.in_channels] += shortcut
return x

Loading…
Cancel
Save