diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 8c08e49f..6ef0c881 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -34,18 +34,17 @@ class EvoNormBatch2d(nn.Module): nn.init.ones_(self.v) def forward(self, x): - assert x.dim() == 4, 'expected 4D input' + _assert(x.dim() == 4, 'expected 4D input') x_type = x.dtype - running_var = self.running_var.view(1, -1, 1, 1) - if self.training: - var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) - n = x.numel() / x.shape[1] - running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum) - self.running_var.copy_(running_var.view(self.running_var.shape)) - else: - var = running_var - if self.v is not None: + running_var = self.running_var.view(1, -1, 1, 1) + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum) + self.running_var.copy_(running_var.view(self.running_var.shape)) + else: + var = running_var v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = d.max((var + self.eps).sqrt().to(dtype=x_type))