import torch
import torch.nn as nn
import torch.nn.functional as F

from .padding import get_padding, get_padding_value, pad_same


def _standardize_weight(weight, eps, scale=None):
    """Standardize a conv weight per output filter (zero mean, unit std).

    Mean and biased (population) std are reduced over dims 1, 2 and 3, i.e. every dim except
    the leading output-channel dim. ``eps`` is added to the std, not to the variance.

    Args:
        weight: 4-D conv weight tensor.
        eps: Stabilizer added to the per-filter std before dividing.
        scale: Optional multiplier applied to the centered weight before the division.
            ``None`` skips the multiply entirely (no extra op is recorded).

    Returns:
        Tensor with the same shape as ``weight``.
    """
    std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    centered = weight - mean
    if scale is not None:
        # Multiply before dividing to keep the exact op order of ``scale * (w - mean) / (std + eps)``.
        centered = scale * centered
    return centered / (std + eps)


def _scaled_std_weight(conv):
    """Shared ``get_weight`` implementation for ScaledStdConv2d and ScaledStdConv2dSame.

    Returns ``conv.gain * conv.scale * standardize(conv.weight)``. Standardization uses
    ``F.layer_norm`` over the non-output-channel dims when ``conv.use_layernorm`` is set,
    otherwise ``_standardize_weight``.
    """
    if conv.use_layernorm:
        # layer_norm adds eps to the variance; the constructors store eps ** 2 in this mode.
        weight = conv.scale * F.layer_norm(conv.weight, conv.weight.shape[1:], eps=conv.eps)
    else:
        weight = _standardize_weight(conv.weight, conv.eps, scale=conv.scale)
    return conv.gain * weight


def get_weight(module):
    """Return ``module.weight`` standardized per output filter, using ``module.eps``.

    Module-level equivalent of ``StdConv2d.get_weight``. ``module`` must expose a 4-D
    ``weight`` and an ``eps`` attribute.
    """
    return _standardize_weight(module.weight, module.eps)


class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args (in addition to the ``nn.Conv2d`` ones):
        padding: Explicit padding; when ``None`` it is derived via
            ``get_padding(kernel_size, stride, dilation)``.
        eps: Added to the per-filter weight std during standardization.
    """

    # NOTE: the first parameter is named ``in_channel`` (singular); kept for keyword-arg compatibility.
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=False, eps=1e-5):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps

    def get_weight(self):
        """Return the weight standardized to zero mean / unit std per output filter."""
        return _standardize_weight(self.weight, self.eps)

    def forward(self, x):
        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class StdConv2dSame(nn.Conv2d):
    """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args (in addition to the ``nn.Conv2d`` ones):
        padding: Padding spec resolved by ``get_padding_value`` (default ``'SAME'``).
        eps: Added to the per-filter weight std during standardization.
    """

    # NOTE: the first parameter is named ``in_channel`` (singular); kept for keyword-arg compatibility.
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
            dilation=1, groups=1, bias=False, eps=1e-5):
        # get_padding_value yields the static padding handed to nn.Conv2d plus a flag that,
        # when set, makes forward() pad the input itself via pad_same.
        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias)
        self.same_pad = is_dynamic
        self.eps = eps

    def get_weight(self):
        """Return the weight standardized to zero mean / unit std per output filter."""
        return _standardize_weight(self.weight, self.eps)

    def forward(self, x):
        if self.same_pad:
            # Dynamic (TF-style) SAME padding is applied to the input here, not by the conv op.
            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.

    Args (in addition to the ``nn.Conv2d`` ones):
        padding: Explicit padding; when ``None`` it is derived via
            ``get_padding(kernel_size, stride, dilation)``.
        gamma: Multiplier folded into the fixed ``1 / sqrt(fan-in)`` weight scale.
        eps: Standardization stabilizer (stored squared when ``use_layernorm`` is set).
        use_layernorm: Standardize via ``F.layer_norm`` instead of ``torch.std_mean``.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias)
        # Learnable per-output-channel gain, broadcast over each filter.
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        # layer_norm adds eps to the variance rather than the std, hence the square.
        self.eps = eps ** 2 if use_layernorm else eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

    def get_weight(self):
        """Return ``gain * scale * standardized(weight)``."""
        return _scaled_std_weight(self)

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)


class ScaledStdConv2dSame(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.

    Args (in addition to the ``nn.Conv2d`` ones):
        padding: Padding spec resolved by ``get_padding_value`` (default ``'SAME'``).
        gamma: Multiplier folded into the fixed ``1 / sqrt(fan-in)`` weight scale.
        eps: Standardization stabilizer (stored squared when ``use_layernorm`` is set).
        use_layernorm: Standardize via ``F.layer_norm`` instead of ``torch.std_mean``.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
        # get_padding_value yields the static padding handed to nn.Conv2d plus a flag that,
        # when set, makes forward() pad the input itself via pad_same.
        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias)
        # Learnable per-output-channel gain, broadcast over each filter.
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.same_pad = is_dynamic
        # layer_norm adds eps to the variance rather than the std, hence the square.
        self.eps = eps ** 2 if use_layernorm else eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

    # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
    # to make much numerical difference (+/- .002 to .004) in top-1 during eval.
    # The gain is applied exactly once (inside `scale`); `gamma` is not folded in by this sketch.
    # def get_weight(self):
    #     var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    #     scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
    #     weight = (self.weight - mean) * scale
    #     return weight

    def get_weight(self):
        """Return ``gain * scale * standardized(weight)``."""
        return _scaled_std_weight(self)

    def forward(self, x):
        if self.same_pad:
            # Dynamic (TF-style) SAME padding is applied to the input here, not by the conv op.
            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)