""" NormAct (Normalizaiton + Activation Layer) Factory Create norm + act combo modules that attempt to be backwards compatible with separate norm + act isntances in models. Where these are used it will be possible to swap separate BN + act layers with combined modules like IABN or EvoNorms. Hacked together by / Copyright 2020 Ross Wightman """ import types import functools import torch import torch.nn as nn from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .norm_act import BatchNormAct2d, GroupNormAct from .inplace_abn import InplaceAbn _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type def get_norm_act_layer(layer_class): layer_class = layer_class.replace('_', '').lower() if layer_class.startswith("batchnorm"): layer = BatchNormAct2d elif layer_class.startswith("groupnorm"): layer = GroupNormAct elif layer_class == "evonormbatch": layer = EvoNormBatch2d elif layer_class == "evonormsample": layer = EvoNormSample2d elif layer_class == "iabn" or layer_class == "inplaceabn": layer = InplaceAbn else: assert False, "Invalid norm_act layer (%s)" % layer_class return layer def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu assert len(layer_parts) in (1, 2) layer = get_norm_act_layer(layer_parts[0]) #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? layer_instance = layer(num_features, apply_act=apply_act, **kwargs) if jit: layer_instance = torch.jit.script(layer_instance) return layer_instance def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) norm_act_args = norm_kwargs.copy() if norm_kwargs else {} if isinstance(norm_layer, str): norm_act_layer = get_norm_act_layer(norm_layer) elif norm_layer in _NORM_ACT_TYPES: norm_act_layer = norm_layer elif isinstance(norm_layer, (types.FunctionType, functools.partial)): # assuming this is a lambda/fn/bound partial that creates norm_act layer norm_act_layer = norm_layer else: type_name = norm_layer.__name__.lower() if type_name.startswith('batchnorm'): norm_act_layer = BatchNormAct2d elif type_name.startswith('groupnorm'): norm_act_layer = GroupNormAct else: assert False, f"No equivalent norm_act layer for {type_name}" if norm_act_layer in _NORM_ACT_REQUIRES_ARG: # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types # It is intended that functions/partial does not trigger this, they should define act. norm_act_args.update(dict(act_layer=act_layer)) return norm_act_layer, norm_act_args