Add 'fast' layer norm that doesn't cast to float32, support APEX LN impl for slight speed gain, update norm and act factories, tweak SE for ability to disable bias (needed by GCVit)
parent c486aa71f8
commit 43aa84e861
@@ -0,0 +1,56 @@
""" Norm Layer Factory

Create norm modules by string (to mirror create_act and create_norm_act fns)

Copyright 2022 Ross Wightman
"""
import types
import functools

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d

_NORM_MAP = dict(
    batchnorm=nn.BatchNorm2d,
    batchnorm2d=nn.BatchNorm2d,
    batchnorm1d=nn.BatchNorm1d,
    groupnorm=GroupNorm,
    groupnorm1=GroupNorm1,
    layernorm=LayerNorm,
    layernorm2d=LayerNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}

def create_norm_layer(layer_name, num_features, **kwargs):
    # resolve the layer type first, then instantiate with the feature count
    layer = get_norm_layer(layer_name)
    layer_instance = layer(num_features, **kwargs)
    return layer_instance


def get_norm_layer(norm_layer):
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    norm_kwargs = {}

    # unbind partial fn, so args can be rebound later
    if isinstance(norm_layer, functools.partial):
        norm_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        layer_name = norm_layer.replace('_', '')
        norm_layer = _NORM_MAP.get(layer_name, None)
    elif norm_layer in _NORM_TYPES:
        norm_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, assume it is a lambda/fn that creates a norm layer
        norm_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower().replace('_', '')
        norm_layer = _NORM_MAP.get(type_name, None)
        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"

    if norm_kwargs:
        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
    return norm_layer
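A minimal usage sketch of the factory above, assuming the module is importable as create_norm (the import path and the example layer choices are illustrative, not part of the diff):

import functools
import torch.nn as nn

from create_norm import get_norm_layer, create_norm_layer  # assumed import path

ln_cls = get_norm_layer('layernorm2d')     # string name -> LayerNorm2d class
bn_cls = get_norm_layer(nn.BatchNorm2d)    # already a known norm type -> returned as-is
gn_fn = get_norm_layer(functools.partial(nn.GroupNorm, num_groups=8))  # partial kwargs rebound onto the mapped GroupNorm

norm = create_norm_layer('batchnorm2d', 64)  # instantiate nn.BatchNorm2d with num_features=64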
@@ -0,0 +1,68 @@
from typing import List, Optional

import torch
from torch.nn import functional as F

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False


# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False  # defaulting to False for now


def is_fast_norm():
    return _USE_FAST_NORM


def set_fast_norm(enable=True):
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable


def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)


def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
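To tie the two hunks together, a hedged sketch of how a norm module could opt into the fast path described in the commit message; LayerNormExample and the import path are assumptions for illustration, not code from this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

from fast_norm import is_fast_norm, fast_layer_norm, set_fast_norm  # assumed import path


class LayerNormExample(nn.LayerNorm):
    # LayerNorm that keeps autocast's low-precision dtype when the fast path is enabled
    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)
        self._fast_norm = is_fast_norm()  # snapshot the global flag at construction time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            return fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


set_fast_norm(True)  # opt in globally before building models; the flag defaults to False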