|
|
|
@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
if self._fast_norm:
|
|
|
|
|
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
else:
|
|
|
|
|
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
return x
|
|
|
|
@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm):
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
x = x.permute(0, 2, 3, 1)
|
|
|
|
|
if self._fast_norm:
|
|
|
|
|
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
else:
|
|
|
|
|
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
|
x = x.permute(0, 3, 1, 2)
|
|
|
|
|