From 43aa84e861b2df12f0ec92db493b4049b753311c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 17 Aug 2022 14:32:58 -0700
Subject: [PATCH] Add 'fast' layer norm that doesn't cast to float32, support
 APEX LN impl for slight speed gain, update norm and act factories, tweak SE
 for ability to disable bias (needed by GCVit)

---
 timm/models/layers/__init__.py       |  4 +-
 timm/models/layers/create_act.py     |  8 +++-
 timm/models/layers/create_norm.py    | 56 +++++++++++++++++++++++
 timm/models/layers/fast_norm.py      | 68 ++++++++++++++++++++++++++++
 timm/models/layers/norm.py           | 51 +++++++++++++++++++--
 timm/models/layers/norm_act.py       | 24 +++++++++--
 timm/models/layers/squeeze_excite.py |  6 +--
 7 files changed, 202 insertions(+), 15 deletions(-)
 create mode 100644 timm/models/layers/create_norm.py
 create mode 100644 timm/models/layers/fast_norm.py

diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index b9eeec0f..071da7bc 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -11,11 +11,13 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
+from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
     EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
@@ -25,7 +27,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index e38f2e03..a3044a3d 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
-    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
+    if inplace is None:
+        return act_layer(**kwargs)
+    try:
+        return act_layer(inplace=inplace, **kwargs)
+    except TypeError:
+        # recover if act layer doesn't have inplace arg
+        return act_layer(**kwargs)
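
For reference, a minimal usage sketch of the create_act_layer fallback above (illustrative
only, not part of the patch; it assumes the timm.models.layers import path used in this tree):

    import torch.nn as nn
    from timm.models.layers import create_act_layer

    # nn.ReLU accepts inplace, so the argument is passed through as before
    act = create_act_layer('relu', inplace=True)

    # nn.PReLU has no `inplace` argument; this used to raise TypeError and now
    # falls back to constructing the layer without it
    act = create_act_layer(nn.PReLU, inplace=True)
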
diff --git a/timm/models/layers/create_norm.py b/timm/models/layers/create_norm.py
new file mode 100644
index 00000000..b9efae8c
--- /dev/null
+++ b/timm/models/layers/create_norm.py
@@ -0,0 +1,56 @@
+""" Norm Layer Factory
+
+Create norm modules by string (to mirror create_act and create_norm_act fns)
+
+Copyright 2022 Ross Wightman
+"""
+import types
+import functools
+
+import torch.nn as nn
+
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+
+_NORM_MAP = dict(
+    batchnorm=nn.BatchNorm2d,
+    batchnorm2d=nn.BatchNorm2d,
+    batchnorm1d=nn.BatchNorm1d,
+    groupnorm=GroupNorm,
+    groupnorm1=GroupNorm1,
+    layernorm=LayerNorm,
+    layernorm2d=LayerNorm2d,
+)
+_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
+
+
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
+    return layer_instance
+
+
+def get_norm_layer(norm_layer):
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    norm_kwargs = {}
+
+    # unbind partial fn, so args can be rebound later
+    if isinstance(norm_layer, functools.partial):
+        norm_kwargs.update(norm_layer.keywords)
+        norm_layer = norm_layer.func
+
+    if isinstance(norm_layer, str):
+        layer_name = norm_layer.replace('_', '')
+        norm_layer = _NORM_MAP.get(layer_name, None)
+    elif norm_layer in _NORM_TYPES:
+        norm_layer = norm_layer
+    elif isinstance(norm_layer, types.FunctionType):
+        # if function type, assume it is a lambda/fn that creates a norm layer
+        norm_layer = norm_layer
+    else:
+        type_name = norm_layer.__name__.lower().replace('_', '')
+        norm_layer = _NORM_MAP.get(type_name, None)
+        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+
+    if norm_kwargs:
+        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
+    return norm_layer
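
For reference, a minimal sketch of how the new norm factory above resolves layers (illustrative
only, not part of the patch; it assumes the timm.models.layers import path used in this tree):

    import functools
    import torch
    from timm.models.layers import get_norm_layer, create_norm_layer, LayerNorm2d

    norm_cls = get_norm_layer('layernorm2d')          # string lookup -> LayerNorm2d
    norm_cls = get_norm_layer(LayerNorm2d)            # known norm types pass straight through
    norm_cls = get_norm_layer(functools.partial(LayerNorm2d, eps=1e-5))  # partial kwargs are rebound

    gn = create_norm_layer('groupnorm', 64)           # instantiate with num_features=64
    y = gn(torch.randn(2, 64, 8, 8))
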
diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py
new file mode 100644
index 00000000..9a34a15e
--- /dev/null
+++ b/timm/models/layers/fast_norm.py
@@ -0,0 +1,68 @@
+from typing import List, Optional
+
+import torch
+from torch.nn import functional as F
+
+try:
+    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
+    has_apex = True
+except ImportError:
+    has_apex = False
+
+
+# fast (ie lower precision LN) can be disabled with this flag if issues crop up
+_USE_FAST_NORM = False  # defaulting to False for now
+
+
+def is_fast_norm():
+    return _USE_FAST_NORM
+
+
+def set_fast_norm(enable=True):
+    global _USE_FAST_NORM
+    _USE_FAST_NORM = enable
+
+
+def fast_group_norm(
+    x: torch.Tensor,
+    num_groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts GN inputs to float32
+        # here we use the low precision autocast dtype
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+
+def fast_layer_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+    if has_apex:
+        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
index 1677dbfa..2ff8fc08 100644
--- a/timm/models/layers/norm.py
+++ b/timm/models/layers/norm.py
@@ -1,17 +1,24 @@
 """ Normalization layers and wrappers
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+
 
 class GroupNorm(nn.GroupNorm):
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
 
 class GroupNorm1(nn.GroupNorm):
@@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm):
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(nn.LayerNorm):
+    """ LayerNorm w/ fast norm option
+    """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x
 
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
 
 
 def _is_contiguous(tensor: torch.Tensor) -> bool:
     # jit is oh so lovely :/
-    # if torch.jit.is_tracing():
-    #     return True
     if torch.jit.is_scripting():
         return tensor.is_contiguous()
     else:
@@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     return x
 
 
+def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+    u = x.mean(dim=1, keepdim=True)
+    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
+    x = (x - u) * torch.rsqrt(s + eps)
+    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
+    return x
+
+
 class LayerNormExp2d(nn.LayerNorm):
     """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
 
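
For reference, a minimal sketch of opting in to the fast norm path above (illustrative only,
not part of the patch). The flag is captured at construction time, so it must be set before
layers are built; the autocast example assumes a CUDA device:

    import torch
    from timm.models.layers import set_fast_norm, LayerNorm2d

    set_fast_norm()                # flip the module-level flag; the default remains off
    norm = LayerNorm2d(64).cuda()  # __init__ records the flag (kept per-module for torchscript)

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        y = norm(torch.randn(2, 64, 8, 8, device='cuda'))
    # fast_layer_norm keeps the computation in the autocast dtype (or dispatches to the APEX
    # fused kernel when apex is installed) instead of upcasting to float32 as native AMP does.
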
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index ea5b7883..be1edead 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -6,8 +6,9 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from .trace_utils import _assert
 from .create_act import get_act_layer
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .trace_utils import _assert
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -177,9 +178,13 @@
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -197,9 +202,13 @@
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -219,8 +228,13 @@ class LayerNormAct2d(nn.LayerNorm):
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
         return x
diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py
index e5da29ef..2e41d956 100644
--- a/timm/models/layers/squeeze_excite.py
+++ b/timm/models/layers/squeeze_excite.py
@@ -27,15 +27,15 @@ class SEModule(nn.Module):
     """
     def __init__(
             self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
-            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
         super(SEModule, self).__init__()
         self.add_maxpool = add_maxpool
         if not rd_channels:
             rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
-        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
         self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
         self.act = create_act_layer(act_layer, inplace=True)
-        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
         self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
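
For reference, a minimal sketch of the new SEModule bias option above (illustrative only, not
part of the patch; the commit notes the option is needed by GCVit, and pairing it with a norm
layer here is just one plausible use):

    import torch
    import torch.nn as nn
    from timm.models.layers.squeeze_excite import SEModule

    # drop the 1x1 conv biases and normalize the squeezed features instead
    se = SEModule(channels=128, rd_ratio=1. / 16, bias=False, norm_layer=nn.BatchNorm2d)
    y = se(torch.randn(2, 128, 7, 7))  # output shape matches the input: (2, 128, 7, 7)
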