You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
3.8 KiB
95 lines
3.8 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .padding import get_padding
|
|
from .conv2d_same import conv2d_same
|
|
|
|
|
|
def get_weight(module):
    """Return ``module.weight`` standardized per output channel (Weight Standardization).

    Every slice along dim 0 (one output filter) is shifted to zero mean and divided
    by its biased standard deviation plus ``module.eps``.

    Args:
        module: any object exposing a ``weight`` tensor whose dim 0 indexes output
            channels, and a float ``eps`` attribute.

    Returns:
        A tensor with the same shape as ``module.weight``.

    Raises:
        ValueError: if ``module.weight`` has fewer than 2 dimensions.
    """
    weight = module.weight
    if weight.dim() < 2:
        # An empty reduction list would make std_mean reduce over *all* elements,
        # silently losing the per-output-channel semantics.
        raise ValueError(
            f"get_weight expects a weight with at least 2 dims, got shape {tuple(weight.shape)}")
    # Reduce over every dim except the output-channel dim, so conv1d/2d/3d kernels all work.
    # For a 4-D Conv2d weight this is exactly dims (1, 2, 3).
    reduce_dims = list(range(1, weight.dim()))
    std, mean = torch.std_mean(weight, dim=reduce_dims, keepdim=True, unbiased=False)
    return (weight - mean) / (std + module.eps)
|
|
|
|
|
|
class StdConv2d(nn.Conv2d):
    """2D convolution whose kernel is weight-standardized on every forward pass.

    Each output filter is normalized to zero mean and (approximately) unit standard
    deviation right before the convolution. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; when ``None`` it is computed by ``get_padding``
            from ``kernel_size``, ``stride`` and ``dilation``.
        dilation: kernel dilation.
        groups: number of channel groups, forwarded to ``nn.Conv2d``.
        bias: whether to learn an additive bias (disabled by default).
        eps: added to each filter's standard deviation to avoid division by zero.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1,
            padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channel, out_channels, kernel_size,
            stride=stride, padding=resolved_padding, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps

    def get_weight(self):
        """Return the kernel standardized per output filter (biased std, ``eps`` added to the std)."""
        w = self.weight
        std, mean = torch.std_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        return (w - mean) / (std + self.eps)

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel using this layer's conv settings."""
        standardized = self.get_weight()
        return F.conv2d(
            x, standardized, bias=self.bias, stride=self.stride,
            padding=self.padding, dilation=self.dilation, groups=self.groups)
|
|
|
|
|
|
class StdConv2dSame(nn.Conv2d):
    """Weight-standardized 2D convolution with TF-compatible SAME padding. Used for ViT Hybrid model.

    The underlying ``nn.Conv2d`` is built with ``padding=0``; the forward pass
    delegates to ``conv2d_same``, which is expected to apply the SAME padding.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: kernel dilation.
        groups: number of channel groups, forwarded to ``nn.Conv2d``.
        bias: whether to learn an additive bias (disabled by default).
        eps: added to each filter's standard deviation to avoid division by zero.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
        conv_kwargs = dict(stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)
        super().__init__(in_channel, out_channels, kernel_size, **conv_kwargs)
        self.eps = eps

    def get_weight(self):
        """Return the kernel standardized per output filter (biased std, ``eps`` added to the std)."""
        kernel = self.weight
        std, mean = torch.std_mean(kernel, dim=[1, 2, 3], keepdim=True, unbiased=False)
        centered = kernel - mean
        return centered / (std + self.eps)

    def forward(self, x):
        """Apply ``conv2d_same`` to ``x`` with the standardized kernel."""
        standardized = self.get_weight()
        return conv2d_same(x, standardized, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
|
|
class ScaledStdConv2d(nn.Conv2d):
    """2D convolution with Scaled Weight Standardization.

    Before every convolution the kernel is standardized per output filter, multiplied
    by the fixed factor ``gamma / sqrt(fan_in)`` and, optionally, by a learnable
    per-output-channel gain.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; when ``None`` it is computed by ``get_padding``
            from ``kernel_size``, ``stride`` and ``dilation``.
        dilation: kernel dilation.
        groups: number of channel groups, forwarded to ``nn.Conv2d``.
        bias: whether to learn an additive bias (enabled by default).
        gain: if truthy, create a learnable gain of shape ``(out_channels, 1, 1, 1)``
            initialized to ones; otherwise ``self.gain`` is ``None``.
        gamma: constant multiplier folded into the fixed scale.
        eps: stabilizer for the standardization (squared when ``use_layernorm`` is set).
        use_layernorm: standardize with ``F.layer_norm`` instead of ``torch.std_mean``
            (experimental, slightly faster/less GPU memory use).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=resolved_padding, dilation=dilation, groups=groups, bias=bias)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        # Element count of one output filter, i.e. its fan-in.
        fan_in = self.weight[0].numel()
        self.scale = gamma * fan_in ** -0.5
        self.use_layernorm = use_layernorm
        # F.layer_norm adds eps to the variance under the square root, so eps is squared to
        # keep it roughly comparable to the `std + eps` form used by the std_mean path.
        self.eps = eps ** 2 if use_layernorm else eps

    def get_weight(self):
        """Return the scaled, standardized kernel, multiplied by ``gain`` when one exists."""
        if self.use_layernorm:
            normed = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
            weight = self.scale * normed
        else:
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            weight = self.scale * (self.weight - mean) / (std + self.eps)
        return weight if self.gain is None else weight * self.gain

    def forward(self, x):
        """Convolve ``x`` with the scaled, standardized kernel using this layer's conv settings."""
        kernel = self.get_weight()
        return F.conv2d(
            x, kernel, bias=self.bias, stride=self.stride,
            padding=self.padding, dilation=self.dilation, groups=self.groups)
|