@@ -23,6 +23,7 @@ GPU, similar train speeds for EvoNormS variants and BatchNorm.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Sequence, Union
 
 import torch
 import torch.nn as nn
@@ -38,36 +39,47 @@ def instance_std(x, eps: float = 1e-5):
 
 
 def instance_rms(x, eps: float = 1e-5):
-    rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
+    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
     return rms.expand(x.shape)
 
 
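The old `x.square().float()` ordering squares in the input dtype before upcasting, so under mixed precision fp16 activations with magnitude above ~256 overflow to inf before the mean is taken; the new `x.float().square()` upcasts first. A minimal sketch of the failure mode (illustrative values, not part of the diff):

    import torch

    x = torch.full((1, 1, 2, 2), 300.0, dtype=torch.float16)
    # square() in fp16: 300**2 = 90000 > 65504 (fp16 max) -> inf
    bad = x.square().float().mean(dim=(2, 3), keepdim=True)
    # upcast to float32 first, then square -> 90000.0
    good = x.float().square().mean(dim=(2, 3), keepdim=True)
    print(bad.item(), good.item())  # inf 90000.0
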
+def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
+    xm = x.mean(dim=dim, keepdim=True)
+    if diff_sqm:
+        # difference of squared mean and mean squared, faster on TPU but can be less stable
+        var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)
+    else:
+        var = (x - xm).square().mean(dim=dim, keepdim=True)
+    return var
+
+
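Note the reductions above must use the `dim` argument, not a hardcoded `(2, 3, 4)`, since `group_std_tpu` below also calls this with `dim=-1` on a flattened tensor. Both paths compute the population variance, either directly as E[(x - m)^2] or via the identity E[x^2] - E[x]^2 (clamped at zero to absorb rounding). A quick sanity sketch, assuming `manual_var` as defined above:

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 32, 4, 8, 8)  # (B, groups, C // groups, H, W)
    ref = x.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    for diff_sqm in (False, True):
        v = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
        print(diff_sqm, torch.allclose(v, ref, atol=1e-6))  # expect: True, True
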
-def group_std(x, groups: int = 32, eps: float = 1e-5):
+def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
     B, C, H, W = x.shape
     x_dtype = x.dtype
     _assert(C % groups == 0, '')
-    # x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
-    # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
-    x = x.reshape(B, groups, C // groups, H, W)
-    std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
-    return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
+    if flatten:
+        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
+        std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+    else:
+        x = x.reshape(B, groups, C // groups, H, W)
+        std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+    return std.expand(x.shape).reshape(B, C, H, W)
 
 
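The new `flatten` flag changes only the intermediate reshape, not the statistic: `(B, groups, -1)` and `(B, groups, C // groups, H, W)` reduce over the same elements per (batch, group). A quick equivalence check, assuming `group_std` as patched above:

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 64, 8, 8)
    s_grouped = group_std(x, groups=32, flatten=False)
    s_flat = group_std(x, groups=32, flatten=True)
    print(torch.allclose(s_grouped, s_flat))  # True
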
 
 
-def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
+def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
     # This is a workaround for some stability / odd behaviour of .var and .std
     # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
     B, C, H, W = x.shape
     _assert(C % groups == 0, '')
-    x_dtype = x.dtype
-    x = x.float().reshape(B, groups, C // groups, H, W)
-    xm = x.mean(dim=(2, 3, 4), keepdim=True)
-    if diff_sqm:
-        # difference of squared mean and mean squared, faster on TPU
-        var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
-    else:
-        var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
-    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
-# group_std = group_std_tpu  # temporary, for TPU / PT XLA
+    if flatten:
+        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
+        var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
+    else:
+        x = x.reshape(B, groups, C // groups, H, W)
+        var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
+    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
+#group_std = group_std_tpu  # FIXME TPU temporary
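Off-TPU, the manual-variance path should be numerically interchangeable with `group_std`; the XLA misbehaviour it works around can't be reproduced here, but the math can be checked (a sketch, assuming both functions above):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 64, 8, 8)
    a = group_std(x, groups=32)
    b = group_std_tpu(x, groups=32)                 # (x - mean)^2 path
    c = group_std_tpu(x, groups=32, diff_sqm=True)  # E[x^2] - E[x]^2 path
    print(torch.allclose(a, b, atol=1e-6), torch.allclose(a, c, atol=1e-5))
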
@@ -75,8 +87,8 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
     _assert(C % groups == 0, '')
     x_dtype = x.dtype
     x = x.reshape(B, groups, C // groups, H, W)
-    sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
-    return sqm.expand(x.shape).reshape(B, C, H, W)
+    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
+    return rms.expand(x.shape).reshape(B, C, H, W)
 
 
 class EvoNorm2dB0(nn.Module):
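`group_rms` is the uncentered analogue of `group_std` (it also gets the float32-upcast fix in this hunk, and is presumably consumed by the RMS-normalizing EvoNorm variants defined below): it reduces E[x^2] rather than the variance, so rms^2 = var + mean^2. A small check of that identity (a sketch; eps is added before the sqrt, as in the function):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 64, 8, 8)
    r = group_rms(x, groups=32)
    g = x.reshape(2, 32, 2, 8, 8)
    var = g.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    mean = g.mean(dim=(2, 3, 4), keepdim=True)
    ref = (var + mean.square() + 1e-5).sqrt().expand(g.shape).reshape(2, 64, 8, 8)
    print(torch.allclose(r, ref, atol=1e-6))  # True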