diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index b9eeec0f..071da7bc 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -11,11 +11,13 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
+from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
     EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
@@ -25,7 +27,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index e38f2e03..a3044a3d 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
-    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
+    if inplace is None:
+        return act_layer(**kwargs)
+    try:
+        return act_layer(inplace=inplace, **kwargs)
+    except TypeError:
+        # recover if act layer doesn't have inplace arg
+        return act_layer(**kwargs)
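
Quick usage sketch of the new create_act_layer fallback (assuming the usual timm.models.layers exports; the gelu/relu choices are just illustrative):

    from timm.models.layers import create_act_layer

    # nn.GELU has no inplace arg, so this previously raised TypeError; it now falls back
    gelu = create_act_layer('gelu', inplace=True)   # -> nn.GELU()
    relu = create_act_layer('relu', inplace=True)   # -> nn.ReLU(inplace=True)
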
diff --git a/timm/models/layers/create_norm.py b/timm/models/layers/create_norm.py
new file mode 100644
index 00000000..b9efae8c
--- /dev/null
+++ b/timm/models/layers/create_norm.py
@@ -0,0 +1,56 @@
+""" Norm Layer Factory
+
+Create norm modules by string (to mirror create_act and create_norm_act fns)
+
+Copyright 2022 Ross Wightman
+"""
+import types
+import functools
+
+import torch.nn as nn
+
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+
+_NORM_MAP = dict(
+    batchnorm=nn.BatchNorm2d,
+    batchnorm2d=nn.BatchNorm2d,
+    batchnorm1d=nn.BatchNorm1d,
+    groupnorm=GroupNorm,
+    groupnorm1=GroupNorm1,
+    layernorm=LayerNorm,
+    layernorm2d=LayerNorm2d,
+)
+_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
+
+
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
+    return layer_instance
+
+
+def get_norm_layer(norm_layer):
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    norm_kwargs = {}
+
+    # unbind partial fn, so args can be rebound later
+    if isinstance(norm_layer, functools.partial):
+        norm_kwargs.update(norm_layer.keywords)
+        norm_layer = norm_layer.func
+
+    if isinstance(norm_layer, str):
+        layer_name = norm_layer.replace('_', '')
+        norm_layer = _NORM_MAP.get(layer_name, None)
+    elif norm_layer in _NORM_TYPES:
+        norm_layer = norm_layer
+    elif isinstance(norm_layer, types.FunctionType):
+        # if function type, assume it is a lambda/fn that creates a norm layer
+        norm_layer = norm_layer
+    else:
+        type_name = norm_layer.__name__.lower().replace('_', '')
+        norm_layer = _NORM_MAP.get(type_name, None)
+        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+
+    if norm_kwargs:
+        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
+    return norm_layer
diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py
new file mode 100644
index 00000000..9a34a15e
--- /dev/null
+++ b/timm/models/layers/fast_norm.py
@@ -0,0 +1,68 @@
+from typing import List, Optional
+
+import torch
+from torch.nn import functional as F
+
+try:
+    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
+    has_apex = True
+except ImportError:
+    has_apex = False
+
+
+# fast (ie lower precision LN) can be disabled with this flag if issues crop up
+_USE_FAST_NORM = False  # defaulting to False for now
+
+
+def is_fast_norm():
+    return _USE_FAST_NORM
+
+
+def set_fast_norm(enable=True):
+    global _USE_FAST_NORM
+    _USE_FAST_NORM = enable
+
+
+def fast_group_norm(
+    x: torch.Tensor,
+    num_groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts GN inputs to float32
+        # here we use the low precision autocast dtype
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+
+def fast_layer_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+    if has_apex:
+        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
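
Usage sketch for the new norm factory based on the mapping above (string names are matched with underscores stripped; kwargs bound in a partial are unbound and rebound onto the resolved layer):

    import functools
    import torch.nn as nn
    from timm.models.layers import get_norm_layer, create_norm_layer

    get_norm_layer('layernorm2d')                               # string -> LayerNorm2d class
    get_norm_layer(nn.BatchNorm2d)                              # known norm type passes through
    get_norm_layer(functools.partial(nn.LayerNorm, eps=1e-6))   # mapped to timm LayerNorm, eps rebound
    norm = create_norm_layer('groupnorm', 64)                   # instantiated with num_features=64
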
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
index 1677dbfa..2ff8fc08 100644
--- a/timm/models/layers/norm.py
+++ b/timm/models/layers/norm.py
@@ -1,17 +1,24 @@
 """ Normalization layers and wrappers
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+
 
 class GroupNorm(nn.GroupNorm):
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
 
 class GroupNorm1(nn.GroupNorm):
@@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm):
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(nn.LayerNorm):
+    """ LayerNorm w/ fast norm option
+    """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x
 
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
 
 
 def _is_contiguous(tensor: torch.Tensor) -> bool:
     # jit is oh so lovely :/
-    # if torch.jit.is_tracing():
-    #     return True
     if torch.jit.is_scripting():
         return tensor.is_contiguous()
     else:
@@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep
     return x
 
 
+def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+    u = x.mean(dim=1, keepdim=True)
+    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
+    x = (x - u) * torch.rsqrt(s + eps)
+    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
+    return x
+
+
 class LayerNormExp2d(nn.LayerNorm):
     """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
 
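The fast-norm flag is captured per instance at construction time (the self.fast_norm / self._fast_norm fields above exist so the modules stay scriptable), so set_fast_norm() must be called before layers are built. A rough sketch, assuming a CUDA device and no apex installed for the autocast behaviour:

    import torch
    from timm.models.layers import set_fast_norm, is_fast_norm, LayerNorm2d

    set_fast_norm(True)              # flip the global flag before constructing layers
    assert is_fast_norm()
    ln = LayerNorm2d(64).cuda()

    with torch.cuda.amp.autocast():
        y = ln(torch.randn(2, 64, 8, 8, device='cuda'))
    # without apex, the norm runs in the autocast dtype rather than being upcast to float32
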
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index ea5b7883..be1edead 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -6,8 +6,9 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from .trace_utils import _assert
 from .create_act import get_act_layer
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .trace_utils import _assert
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -177,9 +178,13 @@ class GroupNormAct(nn.GroupNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -197,9 +202,13 @@ class LayerNormAct(nn.LayerNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -219,8 +228,13 @@ class LayerNormAct2d(nn.LayerNorm):
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
         return x
diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py
index e5da29ef..2e41d956 100644
--- a/timm/models/layers/squeeze_excite.py
+++ b/timm/models/layers/squeeze_excite.py
@@ -27,15 +27,15 @@ class SEModule(nn.Module):
     """
     def __init__(
             self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
-            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
         super(SEModule, self).__init__()
         self.add_maxpool = add_maxpool
         if not rd_channels:
             rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
-        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
         self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
         self.act = create_act_layer(act_layer, inplace=True)
-        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
         self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
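
A small sketch of the new SEModule bias argument (channel count and norm layer here are arbitrary examples):

    import torch.nn as nn
    from timm.models.layers.squeeze_excite import SEModule

    # biases on the 1x1 convs can now be disabled, e.g. when fc1 is followed by a norm layer
    se = SEModule(256, bias=False, norm_layer=nn.BatchNorm2d)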