Realized LayerNorm2d won't work in all cases as-is; fixed.

pull/738/head
Ross Wightman 3 years ago
parent 81cd6863c8
commit 8165cacd82

@@ -15,9 +15,10 @@ class GroupNorm(nn.GroupNorm):

 class LayerNorm2d(nn.LayerNorm):
-    """ Layernorm for channels of '2d' spatial BCHW tensors """
+    """ LayerNorm for channels of '2D' spatial BCHW tensors """
     def __init__(self, num_channels):
-        super().__init__([num_channels, 1, 1])
+        super().__init__(num_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return F.layer_norm(
+            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
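For reference, below is a minimal, self-contained sketch of the layer after this change. The class body mirrors the diff; the imports and the quick check under __main__ are illustrative additions, not part of the commit. The permute is needed because F.layer_norm normalizes over the trailing dimensions of its input, so the tensor is moved to BHWC, normalized over the channel dimension only, then moved back to BCHW.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial BCHW tensors """
    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.layer_norm normalizes over the trailing dims, so move channels last,
        # normalize over C only, then restore the BCHW layout
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)


if __name__ == '__main__':
    x = torch.randn(2, 64, 7, 7)  # any spatial size, not just 1x1
    y = LayerNorm2d(64)(x)
    # each spatial position is normalized across its 64 channels:
    # per-location mean ~0, biased variance ~1
    print(y.shape, y.mean(dim=1).abs().max().item(), y.var(dim=1, unbiased=False).mean().item())

One consequence of the change is that weight and bias are now 1D tensors of shape (num_channels,) rather than (num_channels, 1, 1), the same parameter layout as a plain nn.LayerNorm over the channel dimension.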
