Fix spacing misalignment for fast norm path in LayerNorm modules

pull/1522/head
Ross Wightman 2 years ago
parent 475ecdfa3d
commit 803254bb40

@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self._fast_norm:
-          x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         return x
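
For context, here is a minimal, self-contained sketch of the corrected LayerNorm module. The forward body matches the hunk above; the constructor and the is_fast_norm/fast_layer_norm helpers (which live in timm's fast_norm module in the real code) are stubbed here as assumptions so the snippet runs standalone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for timm's fast-norm helpers; the real
# fast_layer_norm can dispatch to a fused kernel when one is available.
def is_fast_norm() -> bool:
    return False

def fast_layer_norm(x, normalized_shape, weight, bias, eps):
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

class LayerNorm(nn.LayerNorm):
    """LayerNorm with an optional fast-norm path."""
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # resolved once at construction

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x
```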
@@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         if self._fast_norm:
-          x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         else:
             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = x.permute(0, 3, 1, 2)
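
LayerNorm2d wraps the same dispatch in a pair of permutes so that NCHW feature maps are normalized over the channel dimension. A sketch reusing the imports and stubs from the snippet above (constructor again an assumption):

```python
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for NCHW tensors: permute to NHWC, normalize, permute back."""
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC: channels must be last for layer_norm
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return x

# Usage: channel-wise normalization of a batch of feature maps.
norm = LayerNorm2d(64)
y = norm(torch.randn(2, 64, 14, 14))
assert y.shape == (2, 64, 14, 14)
```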
