Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent d17bcb10a5
commit 01f671ed08

@ -347,18 +347,16 @@ class LayerNormGeneral(nn.Module):
self.normalized_dim = normalized_dim self.normalized_dim = normalized_dim
self.use_scale = scale self.use_scale = scale
self.use_bias = bias self.use_bias = bias
self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else torch.ones(affine_shape)
self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else torch.zeros(affine_shape)
self.eps = eps self.eps = eps
def forward(self, x): def forward(self, x):
c = x - x.mean(self.normalized_dim, keepdim=True) c = x - x.mean(self.normalized_dim, keepdim=True)
s = c.pow(2).mean(self.normalized_dim, keepdim=True) s = c.pow(2).mean(self.normalized_dim, keepdim=True)
x = c / torch.sqrt(s + self.eps) x = c / torch.sqrt(s + self.eps)
if self.use_scale: x = x * self.weight
x = x * self.weight x = x + self.bias
if self.use_bias:
x = x + self.bias
return x return x

Loading…
Cancel
Save