diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 42445a49..77d719ed 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: - x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x @@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: - x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2)