From 803254bb40526cadb9a087a90917e599c79d1e94 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 Oct 2022 21:43:49 -0700 Subject: [PATCH] Fix spacing misalignment for fast norm path in LayerNorm modules --- timm/models/layers/norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 42445a49..77d719ed 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: if self._fast_norm: - x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x @@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: - x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2)