diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index cc0a691a..d89aa424 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -53,9 +53,9 @@ def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False): xm = x.mean(dim=dim, keepdim=True) if diff_sqm: # difference of squared mean and mean squared, faster on TPU can be less stable - var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0) + var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0) else: - var = (x - xm).square().mean(dim=dim, keepdim=True) + var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True) return var @@ -121,7 +121,7 @@ class EvoNorm2dB0(nn.Module): if self.v is not None: if self.training: var = x.float().var(dim=(0, 2, 3), unbiased=False) - # var = manual_var(x, dim=(0, 2, 3)) + # var = manual_var(x, dim=(0, 2, 3)).squeeze() n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) +