You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.7 KiB
92 lines
3.7 KiB
4 years ago
|
""" NormAct (Normalizaiton + Activation Layer) Factory
|
||
|
|
||
|
Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
|
||
|
isntances in models. Where these are used it will be possible to swap separate BN + act layers with
|
||
|
combined modules like IABN or EvoNorms.
|
||
|
|
||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||
|
"""
|
||
5 years ago
|
import types
|
||
|
import functools
|
||
|
|
||
3 years ago
|
from .evo_norm import *
|
||
|
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
|
||
3 years ago
|
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
|
||
5 years ago
|
from .inplace_abn import InplaceAbn
|
||
5 years ago
|
|
||
3 years ago
|
_NORM_ACT_MAP = dict(
|
||
|
batchnorm=BatchNormAct2d,
|
||
3 years ago
|
batchnorm2d=BatchNormAct2d,
|
||
3 years ago
|
groupnorm=GroupNormAct,
|
||
2 years ago
|
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
|
||
3 years ago
|
layernorm=LayerNormAct,
|
||
|
layernorm2d=LayerNormAct2d,
|
||
3 years ago
|
evonormb0=EvoNorm2dB0,
|
||
|
evonormb1=EvoNorm2dB1,
|
||
|
evonormb2=EvoNorm2dB2,
|
||
|
evonorms0=EvoNorm2dS0,
|
||
|
evonorms0a=EvoNorm2dS0a,
|
||
|
evonorms1=EvoNorm2dS1,
|
||
|
evonorms1a=EvoNorm2dS1a,
|
||
|
evonorms2=EvoNorm2dS2,
|
||
|
evonorms2a=EvoNorm2dS2a,
|
||
|
frn=FilterResponseNormAct2d,
|
||
|
frntlu=FilterResponseNormTlu2d,
|
||
|
inplaceabn=InplaceAbn,
|
||
|
iabn=InplaceAbn,
|
||
|
)
|
||
|
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
|
||
|
# has act_layer arg to define act type
|
||
3 years ago
|
_NORM_ACT_REQUIRES_ARG = {
|
||
|
BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
|
||
3 years ago
|
|
||
|
|
||
3 years ago
|
def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
|
||
|
layer = get_norm_act_layer(layer_name, act_layer=act_layer)
|
||
5 years ago
|
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
|
||
|
if jit:
|
||
|
layer_instance = torch.jit.script(layer_instance)
|
||
|
return layer_instance
|
||
|
|
||
|
|
||
3 years ago
|
def get_norm_act_layer(norm_layer, act_layer=None):
|
||
5 years ago
|
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
|
||
|
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
|
||
4 years ago
|
norm_act_kwargs = {}
|
||
|
|
||
|
# unbind partial fn, so args can be rebound later
|
||
|
if isinstance(norm_layer, functools.partial):
|
||
|
norm_act_kwargs.update(norm_layer.keywords)
|
||
|
norm_layer = norm_layer.func
|
||
|
|
||
5 years ago
|
if isinstance(norm_layer, str):
|
||
3 years ago
|
layer_name = norm_layer.replace('_', '').lower().split('-')[0]
|
||
|
norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
|
||
5 years ago
|
elif norm_layer in _NORM_ACT_TYPES:
|
||
|
norm_act_layer = norm_layer
|
||
4 years ago
|
elif isinstance(norm_layer, types.FunctionType):
|
||
|
# if function type, must be a lambda/fn that creates a norm_act layer
|
||
5 years ago
|
norm_act_layer = norm_layer
|
||
|
else:
|
||
|
type_name = norm_layer.__name__.lower()
|
||
|
if type_name.startswith('batchnorm'):
|
||
|
norm_act_layer = BatchNormAct2d
|
||
|
elif type_name.startswith('groupnorm'):
|
||
|
norm_act_layer = GroupNormAct
|
||
2 years ago
|
elif type_name.startswith('groupnorm1'):
|
||
|
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
|
||
3 years ago
|
elif type_name.startswith('layernorm2d'):
|
||
|
norm_act_layer = LayerNormAct2d
|
||
|
elif type_name.startswith('layernorm'):
|
||
|
norm_act_layer = LayerNormAct
|
||
5 years ago
|
else:
|
||
|
assert False, f"No equivalent norm_act layer for {type_name}"
|
||
4 years ago
|
|
||
5 years ago
|
if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
|
||
4 years ago
|
# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
|
||
5 years ago
|
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
|
||
4 years ago
|
norm_act_kwargs.setdefault('act_layer', act_layer)
|
||
|
if norm_act_kwargs:
|
||
|
norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
|
||
|
return norm_act_layer
|