|
|
|
@ -53,9 +53,9 @@ def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
|
|
|
|
|
xm = x.mean(dim=dim, keepdim=True)
|
|
|
|
|
if diff_sqm:
|
|
|
|
|
# difference of squared mean and mean squared, faster on TPU can be less stable
|
|
|
|
|
var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)
|
|
|
|
|
var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)
|
|
|
|
|
else:
|
|
|
|
|
var = (x - xm).square().mean(dim=dim, keepdim=True)
|
|
|
|
|
var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)
|
|
|
|
|
return var
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -121,7 +121,7 @@ class EvoNorm2dB0(nn.Module):
|
|
|
|
|
if self.v is not None:
|
|
|
|
|
if self.training:
|
|
|
|
|
var = x.float().var(dim=(0, 2, 3), unbiased=False)
|
|
|
|
|
# var = manual_var(x, dim=(0, 2, 3))
|
|
|
|
|
# var = manual_var(x, dim=(0, 2, 3)).squeeze()
|
|
|
|
|
n = x.numel() / x.shape[1]
|
|
|
|
|
self.running_var.copy_(
|
|
|
|
|
self.running_var * (1 - self.momentum) +
|
|
|
|
|