""" EvoNorm in PyTorch Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967 @inproceedings{NEURIPS2020, author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc}, booktitle = {Advances in Neural Information Processing Systems}, editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin}, pages = {13539--13550}, publisher = {Curran Associates, Inc.}, title = {Evolving Normalization-Activation Layers}, url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf}, volume = {33}, year = {2020} } An attempt at getting decent performing EvoNorms running in PyTorch. While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm in terms of memory usage and throughput on GPUs. I'm testing these modules on TPU w/ PyTorch XLA. Promising start but currently working around some issues with builtin torch/tensor.var/std. Unlike GPU, similar train speeds for EvoNormS variants and BatchNorm. Hacked together by / Copyright 2020 Ross Wightman """ import torch import torch.nn as nn import torch.nn.functional as F from .create_act import create_act_layer from .trace_utils import _assert def instance_std(x, eps: float = 1e-5): rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) return rms.expand(x.shape) def instance_rms(x, eps: float = 1e-5): rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype) return rms.expand(x.shape) def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): B, C, H, W = x.shape x_dtype = x.dtype _assert(C % groups == 0, '') # x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt() x = x.reshape(B, groups, C // groups, H, W) std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt() return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype) def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False): # This is a workaround for some stability / odd behaviour of .var and .std # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results B, C, H, W = x.shape _assert(C % groups == 0, '') x_dtype = x.dtype x = x.float().reshape(B, groups, C // groups, H, W) xm = x.mean(dim=(2, 3, 4), keepdim=True) if diff_sqm: # difference of squared mean and mean squared, faster on TPU var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0) else: var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True) return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype) # group_std = group_std_tpu # temporary, for TPU / PT XLA def group_rms(x, groups: int = 32, eps: float = 1e-5): B, C, H, W = x.shape _assert(C % groups == 0, '') x_dtype = x.dtype x = x.reshape(B, groups, C // groups, H, W) sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) return sqm.expand(x.shape).reshape(B, C, H, W) class EvoNorm2dB0(nn.Module): def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) if self.v is not None: nn.init.ones_(self.v) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.v is not None: if self.training: var = x.float().var(dim=(0, 2, 3), unbiased=False) n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + var.detach() * self.momentum * (n / (n - 1))) else: var = self.running_var left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x) v = self.v.to(x_dtype).view(v_shape) right = x * v + instance_std(x, self.eps) x = x / left.max(right) return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape) class EvoNorm2dB1(nn.Module): def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.apply_act: if self.training: var = x.float().var(dim=(0, 2, 3), unbiased=False) n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var var = var.to(dtype=x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = (x + 1) * instance_rms(x, self.eps) x = x / left.max(right) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dB2(nn.Module): def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.apply_act: if self.training: var = x.float().var(dim=(0, 2, 3), unbiased=False) n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var var = var.to(dtype=x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = instance_rms(x, self.eps) - x x = x / left.max(right) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS0(nn.Module): def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) if group_size: assert num_features % group_size == 0 self.groups = num_features // group_size else: self.groups = groups self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) if self.v is not None: nn.init.ones_(self.v) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.v is not None: v = self.v.view(v_shape).to(dtype=x_dtype) x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS0a(EvoNorm2dS0): def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) d = group_std(x, self.groups, self.eps) if self.v is not None: v = self.v.view(v_shape).to(dtype=x_dtype) x = x * (x * v).sigmoid_() x = x / d return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS1(nn.Module): def __init__( self, num_features, groups=32, group_size=None, apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) else: self.act = nn.Identity() if group_size: assert num_features % group_size == 0 self.groups = num_features // group_size else: self.groups = groups self.eps = eps self.pre_act_norm = False self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_std(x, self.groups, self.eps) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS1a(EvoNorm2dS1): def __init__( self, num_features, groups=32, group_size=None, apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_std(x, self.groups, self.eps) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS2(nn.Module): def __init__( self, num_features, groups=32, group_size=None, apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) else: self.act = nn.Identity() if group_size: assert num_features % group_size == 0 self.groups = num_features // group_size else: self.groups = groups self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_rms(x, self.groups, self.eps) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) class EvoNorm2dS2a(EvoNorm2dS2): def __init__( self, num_features, groups=32, group_size=None, apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_rms(x, self.groups, self.eps) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)