import torch
import torch.nn as nn
import torch.nn.functional as F

from .padding import get_padding, get_padding_value, pad_same


def _standardize_weight(weight, eps, use_layernorm):
    """Standardize a 4-D conv weight per output filter (over dims 1, 2, 3).

    Args:
        weight: Conv weight tensor; statistics are taken over every dim except the first.
        eps: Small constant added for numerical stability.
        use_layernorm: If True, use the fused ``F.layer_norm`` kernel, otherwise compute the
            statistics explicitly with ``torch.std_mean``.

    Returns:
        A newly allocated tensor with the same shape as ``weight``. The input is never modified,
        so callers may scale the result in place.
    """
    if use_layernorm:
        # NOTE F.layer_norm does the whole normalization in one op. It computes
        # (weight - mean) / sqrt(var + eps), i.e. eps is added INSIDE the sqrt, so the result is
        # close to, but not bit-identical with, the explicit branch below.
        return F.layer_norm(weight, weight.shape[1:], eps=eps)
    # Explicit form: (weight - mean) / (std + eps), with biased (population) std.
    std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    return (weight - mean) / (std + eps)


def get_weight(module):
    """Return the standardized weight of ``module`` using the explicit std/mean formulation.

    Reads ``module.weight`` and ``module.eps``. This function always uses the explicit
    ``(weight - mean) / (std + eps)`` path and does not consult ``module.use_layernorm``.
    """
    return _standardize_weight(module.weight, module.eps, use_layernorm=False)


class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args:
        in_channel: Number of input channels. NOTE: spelled ``in_channel`` (singular), unlike the
            ``Scaled*`` classes in this file; kept as-is so existing keyword callers keep working.
        out_channels: Number of output channels.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Explicit padding, or None to use ``get_padding(kernel_size, stride, dilation)``.
        eps: Stability constant used when standardizing the weight.
        use_layernorm: Standardize via ``F.layer_norm`` instead of explicit std/mean.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=False, eps=1e-5, use_layernorm=True):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps
        self.use_layernorm = use_layernorm

    def get_weight(self):
        """Return the standardized convolution weight (recomputed on every call)."""
        return _standardize_weight(self.weight, self.eps, self.use_layernorm)

    def forward(self, x):
        """Convolve ``x`` with the standardized weight."""
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)


class StdConv2dSame(nn.Conv2d):
    """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args:
        in_channel: Number of input channels. NOTE: spelled ``in_channel`` (singular), unlike the
            ``Scaled*`` classes in this file; kept as-is so existing keyword callers keep working.
        out_channels: Number of output channels.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Padding spec resolved by ``get_padding_value``; that helper returns the padding
            handed to ``nn.Conv2d`` plus a flag saying whether ``pad_same`` must run in forward.
        eps: Stability constant used when standardizing the weight.
        use_layernorm: Standardize via ``F.layer_norm`` instead of explicit std/mean.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
            dilation=1, groups=1, bias=False, eps=1e-5, use_layernorm=True):
        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.same_pad = is_dynamic  # True -> input is padded with pad_same() at forward time
        self.eps = eps
        self.use_layernorm = use_layernorm

    def get_weight(self):
        """Return the standardized convolution weight (recomputed on every call)."""
        return _standardize_weight(self.weight, self.eps, self.use_layernorm)

    def forward(self, x):
        """Optionally apply ``pad_same`` to ``x``, then convolve with the standardized weight."""
        if self.same_pad:
            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Explicit padding, or None to use ``get_padding(kernel_size, stride, dilation)``.
        gamma: Constant multiplier folded into the fixed weight scale.
        eps: Stability constant used when standardizing the weight.
        gain_init: Initial value of the learnable per-output-channel gain (int or float).
        use_layernorm: Standardize via ``F.layer_norm`` instead of explicit std/mean.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        # float() keeps the gain a floating point tensor even if an int (e.g. gain_init=1) is passed;
        # torch.full would otherwise infer an integer dtype, which cannot be a trainable Parameter.
        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), float(gain_init)))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

    def get_weight(self):
        """Return the standardized weight multiplied by ``gain * scale`` (recomputed on every call)."""
        weight = _standardize_weight(self.weight, self.eps, self.use_layernorm)
        # `weight` is a fresh tensor, so scaling it in place never touches the parameter itself.
        return weight.mul_(self.gain * self.scale)

    def forward(self, x):
        """Convolve ``x`` with the scaled, standardized weight."""
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)


class ScaledStdConv2dSame(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Padding spec resolved by ``get_padding_value``; that helper returns the padding
            handed to ``nn.Conv2d`` plus a flag saying whether ``pad_same`` must run in forward.
        gamma: Constant multiplier folded into the fixed weight scale.
        eps: Stability constant used when standardizing the weight.
        gain_init: Initial value of the learnable per-output-channel gain (int or float).
        use_layernorm: Standardize via ``F.layer_norm`` instead of explicit std/mean.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        # float() keeps the gain a floating point tensor even if an int (e.g. gain_init=1) is passed;
        # torch.full would otherwise infer an integer dtype, which cannot be a trainable Parameter.
        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), float(gain_init)))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.same_pad = is_dynamic  # True -> input is padded with pad_same() at forward time
        self.eps = eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

    def get_weight(self):
        """Return the standardized weight multiplied by ``gain * scale`` (recomputed on every call)."""
        weight = _standardize_weight(self.weight, self.eps, self.use_layernorm)
        # `weight` is a fresh tensor, so scaling it in place never touches the parameter itself.
        return weight.mul_(self.gain * self.scale)

    def forward(self, x):
        """Optionally apply ``pad_same`` to ``x``, then convolve with the scaled, standardized weight."""
        if self.same_pad:
            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)