Add 'fast' layer norm that skips the float32 cast, support the APEX LN impl for a slight speed gain, update the norm and act factories, tweak SE to allow disabling bias (needed by GCVit)

pull/1415/head
Ross Wightman 2 years ago
parent c486aa71f8
commit 43aa84e861

@@ -11,11 +11,13 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
+from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
@@ -25,7 +27,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed

@@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
-    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
+    if inplace is None:
+        return act_layer(**kwargs)
+    try:
+        return act_layer(inplace=inplace, **kwargs)
+    except TypeError:
+        # recover if act layer doesn't have inplace arg
+        return act_layer(**kwargs)
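
A minimal sketch of the new fallback behaviour, assuming create_act_layer is imported from timm.models.layers as exported above; nn.GELU is used only because it has no inplace arg:

    import torch.nn as nn
    from timm.models.layers import create_act_layer

    # nn.GELU takes no inplace arg; the old one-liner would raise TypeError here,
    # now the kwarg is dropped and the layer is still constructed
    act = create_act_layer(nn.GELU, inplace=True)
    assert isinstance(act, nn.GELU)

    # layers that do accept inplace still get it forwarded
    act = create_act_layer(nn.ReLU, inplace=True)
    assert act.inplace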

@@ -0,0 +1,56 @@
+""" Norm Layer Factory
+
+Create norm modules by string (to mirror create_act and create_norm_act fns)
+
+Copyright 2022 Ross Wightman
+"""
+import types
+import functools
+
+import torch.nn as nn
+
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+
+_NORM_MAP = dict(
+    batchnorm=nn.BatchNorm2d,
+    batchnorm2d=nn.BatchNorm2d,
+    batchnorm1d=nn.BatchNorm1d,
+    groupnorm=GroupNorm,
+    groupnorm1=GroupNorm1,
+    layernorm=LayerNorm,
+    layernorm2d=LayerNorm2d,
+)
+_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
+
+
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
+    return layer_instance
+
+
+def get_norm_layer(norm_layer):
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    norm_kwargs = {}
+
+    # unbind partial fn, so args can be rebound later
+    if isinstance(norm_layer, functools.partial):
+        norm_kwargs.update(norm_layer.keywords)
+        norm_layer = norm_layer.func
+
+    if isinstance(norm_layer, str):
+        layer_name = norm_layer.replace('_', '')
+        norm_layer = _NORM_MAP.get(layer_name, None)
+    elif norm_layer in _NORM_TYPES:
+        norm_layer = norm_layer
+    elif isinstance(norm_layer, types.FunctionType):
+        # if function type, assume it is a lambda/fn that creates a norm layer
+        norm_layer = norm_layer
+    else:
+        type_name = norm_layer.__name__.lower().replace('_', '')
+        norm_layer = _NORM_MAP.get(type_name, None)
+        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+
+    if norm_kwargs:
+        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
+    return norm_layer
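
For reference, a small usage sketch of the new factory, assuming get_norm_layer and create_norm_layer are exported from timm.models.layers as in the __init__ hunk above (the create_norm_layer call follows the simplified signature shown here):

    import functools
    import torch
    import torch.nn as nn
    from timm.models.layers import get_norm_layer, create_norm_layer

    ln_cls = get_norm_layer('layernorm2d')      # string lookup -> LayerNorm2d
    bn_cls = get_norm_layer(nn.BatchNorm2d)     # a known norm type passes straight through
    bn_partial = get_norm_layer(functools.partial(nn.BatchNorm2d, momentum=0.01))  # partial kwargs are rebound

    norm = create_norm_layer('layernorm2d', 64)
    y = norm(torch.randn(2, 64, 7, 7))          # NCHW in, NCHW out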

@@ -0,0 +1,68 @@
+from typing import List, Optional
+
+import torch
+from torch.nn import functional as F
+
+try:
+    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
+    has_apex = True
+except ImportError:
+    has_apex = False
+
+
+# fast (ie lower precision LN) can be disabled with this flag if issues crop up
+_USE_FAST_NORM = False  # defaulting to False for now
+
+
+def is_fast_norm():
+    return _USE_FAST_NORM
+
+
+def set_fast_norm(enable=True):
+    global _USE_FAST_NORM
+    _USE_FAST_NORM = enable
+
+
+def fast_group_norm(
+    x: torch.Tensor,
+    num_groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts GN inputs to float32
+        # here we use the low precision autocast dtype instead
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+
+def fast_layer_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+    if has_apex:
+        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
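
A sketch of the intended effect, assuming a CUDA device, AMP autocast, and no apex install (when apex is present the fused LN path above is taken unconditionally):

    import torch
    from timm.models.layers import set_fast_norm, LayerNorm2d

    # the flag is read at module construction time, so enable it before building layers
    set_fast_norm(True)
    norm = LayerNorm2d(64).cuda()

    x = torch.randn(2, 64, 14, 14, device='cuda')
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        y = norm(x)

    # fast norm keeps the computation in the autocast dtype rather than letting
    # native AMP upcast LN inputs to float32
    print(y.dtype)  # torch.bfloat16 here; torch.float32 with fast norm left disabled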

@@ -1,16 +1,23 @@
""" Normalization layers and wrappers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+

class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
-        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm):
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(nn.LayerNorm):
+    """ LayerNorm w/ fast norm option
+    """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x


def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
+    # if torch.jit.is_tracing():
+    #     return True
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
@@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep
    return x


+def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+    u = x.mean(dim=1, keepdim=True)
+    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
+    x = (x - u) * torch.rsqrt(s + eps)
+    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
+    return x
+
+
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

@@ -6,8 +6,9 @@ import torch
from torch import nn as nn
from torch.nn import functional as F

-from .trace_utils import _assert
from .create_act import get_act_layer
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .trace_utils import _assert


class BatchNormAct2d(nn.BatchNorm2d):
@@ -177,8 +178,12 @@ class GroupNormAct(nn.GroupNorm):
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()

    def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
@@ -197,8 +202,12 @@ class LayerNormAct(nn.LayerNorm):
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()

    def forward(self, x):
-        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
@@ -219,8 +228,12 @@ class LayerNormAct2d(nn.LayerNorm):
            self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()  # forward below relies on this flag (as in GroupNormAct/LayerNormAct)

    def forward(self, x):
-        x = F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
        x = self.drop(x)
        x = self.act(x)
        return x
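
The norm+act wrappers pick up the same construction-time flag; a minimal sketch, assuming get_norm_act_layer resolves 'groupnorm' to GroupNormAct as in the existing norm-act factory:

    import torch
    import torch.nn as nn
    from timm.models.layers import get_norm_act_layer, set_fast_norm

    set_fast_norm(True)  # GroupNormAct reads is_fast_norm() in __init__
    norm_act_cls = get_norm_act_layer('groupnorm', act_layer=nn.ReLU)
    norm_act = norm_act_cls(64)  # 64 channels, default 32 groups

    y = norm_act(torch.randn(2, 64, 14, 14))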

@@ -27,15 +27,15 @@ class SEModule(nn.Module):
    """
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
-            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super(SEModule, self).__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
-        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
        self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
        self.act = create_act_layer(act_layer, inplace=True)
-        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
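
And the new bias flag in use, a minimal sketch assuming SEModule stays exported from timm.models.layers (GCVit-style attention being the motivating case per the commit message):

    import torch
    from timm.models.layers import SEModule

    se = SEModule(256)                                    # default keeps biased 1x1 convs
    assert se.fc1.bias is not None

    se_nobias = SEModule(256, rd_ratio=0.25, bias=False)  # bias dropped from both fc convs
    assert se_nobias.fc1.bias is None and se_nobias.fc2.bias is None

    y = se_nobias(torch.randn(2, 256, 7, 7))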
