diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py
index 5032a527..cc0a691a 100644
--- a/timm/models/layers/evo_norm.py
+++ b/timm/models/layers/evo_norm.py
@@ -34,8 +34,14 @@ from .trace_utils import _assert
 
 
 def instance_std(x, eps: float = 1e-5):
-    rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
-    return rms.expand(x.shape)
+    std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
+    return std.expand(x.shape)
+
+
+def instance_std_tpu(x, eps: float = 1e-5):
+    std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
+    return std.expand(x.shape)
+# instance_std = instance_std_tpu
 
 
 def instance_rms(x, eps: float = 1e-5):
@@ -47,9 +53,9 @@ def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
     xm = x.mean(dim=dim, keepdim=True)
     if diff_sqm:
         # difference of squared mean and mean squared, faster on TPU can be less stable
-        var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
+        var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)
     else:
-        var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
+        var = (x - xm).square().mean(dim=dim, keepdim=True)
     return var
 
 
@@ -57,7 +63,6 @@ def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
     B, C, H, W = x.shape
     x_dtype = x.dtype
     _assert(C % groups == 0, '')
-    torch.var()
     if flatten:
         x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
         std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
@@ -116,6 +121,7 @@ class EvoNorm2dB0(nn.Module):
         if self.v is not None:
             if self.training:
                 var = x.float().var(dim=(0, 2, 3), unbiased=False)
+                # var = manual_var(x, dim=(0, 2, 3))
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
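
For reference, a minimal sanity check (a sketch, assuming the module path shown in this diff) that `manual_var` agrees with `torch.var(unbiased=False)` on both code paths once the hard-coded `(2, 3, 4)` dims are replaced by the `dim` argument:

```python
import torch

from timm.models.layers.evo_norm import manual_var  # path as of this diff

x = torch.randn(2, 8, 4, 4)
ref = x.var(dim=(2, 3), unbiased=False, keepdim=True)

# default path: mean of squared differences from the mean
assert torch.allclose(manual_var(x, dim=(2, 3)), ref, atol=1e-6)

# diff_sqm path: E[x^2] - E[x]^2, same value with a looser tolerance
# since the diff's own comment flags it as less numerically stable
assert torch.allclose(manual_var(x, dim=(2, 3), diff_sqm=True), ref, atol=1e-5)
```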