Fix spacing misalignment for fast norm path in LayerNorm modules

pull/1649/head
Ross Wightman 2 years ago
parent 3044b1a1cb
commit 3ec5918a74

@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
-             x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         return x
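
For context, a minimal self-contained sketch of the pattern the hunk above touches: a `LayerNorm` subclass that routes through a fused fast-norm path when enabled and falls back to `F.layer_norm` otherwise. The `fast_layer_norm` function below is a stand-in for timm's fused helper (which may dispatch to an APEX kernel), and the constructor signature and hard-coded `_fast_norm` flag are assumptions for illustration, not the file's verbatim contents.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fast_layer_norm(x, normalized_shape, weight, bias, eps):
    # Stand-in for timm's fused fast-norm helper; assumption for illustration only.
    return F.layer_norm(x, normalized_shape, weight, bias, eps)


class LayerNorm(nn.LayerNorm):
    """LayerNorm with an optional fast-norm path, mirroring the diff above (sketch)."""

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        # In timm this flag is derived from a global fast-norm switch; hard-coding
        # it here keeps the sketch self-contained.
        self._fast_norm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x


# Both paths normalize the last dimension of a (batch, tokens, channels) tensor.
x = torch.randn(2, 8, 64)
print(LayerNorm(64)(x).shape)  # torch.Size([2, 8, 64])
```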
@@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
-             x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = x.permute(0, 3, 1, 2)
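
The second hunk applies the same spacing fix inside `LayerNorm2d`, which handles NCHW feature maps by permuting to channels-last before normalizing and permuting back afterwards. A short usage sketch, assuming the class is exported as `timm.layers.LayerNorm2d` (the channel count and tensor shape are arbitrary examples):

```python
import torch
from timm.layers import LayerNorm2d  # assumes a recent timm with the layers package

norm = LayerNorm2d(64)              # normalizes over C for NCHW inputs
x = torch.randn(2, 64, 14, 14)      # N, C, H, W feature map
y = norm(x)                         # permute to NHWC -> layer_norm -> permute back to NCHW
assert y.shape == x.shape
```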
