""" Normalization layers and wrappers
|
||
2 years ago
|
|
||
|
Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
|
||
|
|
||
|
Hacked together by / Copyright 2022 Ross Wightman
|
||
4 years ago
|
"""
|
||
2 years ago
|
|
||
4 years ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
2 years ago
|
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
|
||
|
|
||
4 years ago
|
|
||
class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

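# Usage sketch (illustration only, not part of the original module): because num_channels is the
# first positional arg here, GroupNorm can be passed around interchangeably with nn.BatchNorm2d
# as a `norm_layer` callable. The 64-channel shapes below are assumptions for illustration.
#
#   norm_layer = GroupNorm                  # or nn.BatchNorm2d, same channel-first call pattern
#   norm = norm_layer(64)                   # 64 channels, default 32 groups
#   y = norm(torch.randn(2, 64, 8, 8))      # y.shape == (2, 64, 8, 8)
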
class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

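# Usage sketch (illustration only): GroupNorm1(C) behaves like nn.GroupNorm(1, C), i.e. statistics
# are computed over all of C plus the trailing spatial/sequence dims, and any [B, C, *] layout is
# accepted. Shapes below are assumptions for illustration.
#
#   gn1 = GroupNorm1(64)
#   y_2d = gn1(torch.randn(2, 64, 14, 14))   # [B, C, H, W]
#   y_1d = gn1(torch.randn(2, 64, 196))      # [B, C, L]
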
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x

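# Note / sketch (not part of the original module): the fast-norm flag is read once from
# is_fast_norm() at construction time, so any global toggle (e.g. a set_fast_norm() helper,
# assumed here to live in .fast_norm) would need to be flipped before the layer is built.
# Shapes below are assumptions for illustration.
#
#   ln = LayerNorm(192)                      # channels-last / [B, N, C] style usage
#   y = ln(torch.randn(2, 196, 192))         # normalizes over the final (channel) dim
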
class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x

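# Usage sketch (illustration only): LayerNorm2d keeps NCHW at its interface and normalizes over C
# per spatial position by permuting to NHWC, applying layer norm, then permuting back. Shapes are
# assumptions for illustration.
#
#   ln2d = LayerNorm2d(64)
#   y = ln2d(torch.randn(2, 64, 14, 14))     # input and output both [B, C, H, W]
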
def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)

@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x

def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    u = x.mean(dim=1, keepdim=True)
    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
    return x

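# Sanity-check sketch (illustration only, not part of the original module): both helpers apply
# layer norm over dim=1 (channels) of an NCHW tensor; _layer_norm_cf uses torch.var_mean, while
# _layer_norm_cf_sqm uses the E[x^2] - E[x]^2 form (clamped at 0 for numerical safety). Either
# should match F.layer_norm on a permuted NHWC view up to floating point tolerance.
#
#   x = torch.randn(2, 64, 8, 8)
#   w, b = torch.ones(64), torch.zeros(64)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), w, b, 1e-6).permute(0, 3, 1, 2)
#   torch.allclose(_layer_norm_cf(x, w, b, 1e-6), ref, atol=1e-5)       # expected ~True
#   torch.allclose(_layer_norm_cf_sqm(x, w, b, 1e-6), ref, atol=1e-5)   # expected ~True
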
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            x = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
        return x
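# Usage sketch (illustration only): a (contiguous_format-)contiguous NCHW input takes the
# permute + F.layer_norm path; a channels_last tensor falls through to the manual _layer_norm_cf
# path, which is where the docstring's throughput note applies. Shapes are assumptions.
#
#   ln = LayerNormExp2d(64)
#   x = torch.randn(2, 64, 14, 14).to(memory_format=torch.channels_last)
#   y = ln(x)                                # non-contiguous in contiguous_format -> manual path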