Last change wasn't complete, missed adding full evo_norm changeset

pull/1239/head
Ross Wightman 3 years ago
parent 7bbbd5ef1b
commit 66daee4f31

@@ -34,8 +34,14 @@ from .trace_utils import _assert


 def instance_std(x, eps: float = 1e-5):
-    rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
-    return rms.expand(x.shape)
+    std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
+    return std.expand(x.shape)
+
+
+def instance_std_tpu(x, eps: float = 1e-5):
+    std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
+    return std.expand(x.shape)
+# instance_std = instance_std_tpu


 def instance_rms(x, eps: float = 1e-5):
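
Not part of the commit, just a hedged sanity check of the change above: instance_std_tpu computes the same per-instance std as instance_std, only via the explicit manual_var reduction instead of torch.var (which this changeset works around for TPU / XLA). Shapes and tolerances below are illustrative.

import torch

x = torch.randn(2, 8, 4, 4)
eps = 1e-5
# torch.var path, as in instance_std
ref = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt()
# manual path, as instance_std_tpu does via manual_var (default, non-diff_sqm branch)
xm = x.mean(dim=(2, 3), keepdim=True)
alt = (x - xm).square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt()
torch.testing.assert_close(ref, alt, rtol=1e-5, atol=1e-6)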
@@ -47,9 +53,9 @@ def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
     xm = x.mean(dim=dim, keepdim=True)
     if diff_sqm:
         # difference of squared mean and mean squared, faster on TPU can be less stable
-        var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
+        var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)
     else:
-        var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
+        var = (x - xm).square().mean(dim=dim, keepdim=True)
     return var
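
For reference (not in the diff): the diff_sqm branch relies on the identity Var(x) = E[x^2] - E[x]^2, and the fix above just makes both branches honor the passed-in dim instead of the previously hard-coded (2, 3, 4) axes. A minimal check with an arbitrary reduction dim:

import torch

x = torch.randn(2, 4, 8, 8)
dim = (2, 3)
xm = x.mean(dim=dim, keepdim=True)
v_sub = (x - xm).square().mean(dim=dim, keepdim=True)                     # E[(x - E[x])^2]
v_sqm = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)   # E[x^2] - E[x]^2
torch.testing.assert_close(v_sub, v_sqm, rtol=1e-4, atol=1e-6)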
@@ -57,7 +63,6 @@ def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
     B, C, H, W = x.shape
     x_dtype = x.dtype
     _assert(C % groups == 0, '')
-    torch.var()
     if flatten:
         x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
         std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
@@ -116,6 +121,7 @@ class EvoNorm2dB0(nn.Module):
         if self.v is not None:
             if self.training:
                 var = x.float().var(dim=(0, 2, 3), unbiased=False)
+                # var = manual_var(x, dim=(0, 2, 3))
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
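
The hunk is cut off mid-statement; below is a hedged sketch of the surrounding running_var update for context. The n / (n - 1) bias correction and the detach are assumptions based on the usual BatchNorm-style EMA update, not shown in this hunk.

import torch

momentum = 0.1
running_var = torch.ones(8)
x = torch.randn(4, 8, 16, 16)
var = x.float().var(dim=(0, 2, 3), unbiased=False)   # batch variance, as in the hunk
n = x.numel() / x.shape[1]                           # elements per channel
# assumed continuation: EMA of running_var with unbiased correction n / (n - 1)
running_var.copy_(running_var * (1 - momentum) + var.detach() * momentum * (n / (n - 1)))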
