""" Select AttentionFactory Method Hacked together by / Copyright 2020 Ross Wightman """ import torch from .cbam import CbamModule, LightCbamModule from .eca import EcaModule, CecaModule from .gather_excite import GatherExcite from .global_context import GlobalContext from .squeeze_excite import SEModule, EffectiveSEModule def get_attn(attn_type): if isinstance(attn_type, torch.nn.Module): return attn_type module_cls = None if attn_type is not None: if isinstance(attn_type, str): attn_type = attn_type.lower() if attn_type == 'se': module_cls = SEModule elif attn_type == 'ese': module_cls = EffectiveSEModule elif attn_type == 'eca': module_cls = EcaModule elif attn_type == 'ceca': module_cls = CecaModule elif attn_type == 'ge': module_cls = GatherExcite elif attn_type == 'gc': module_cls = GlobalContext elif attn_type == 'cbam': module_cls = CbamModule elif attn_type == 'lcbam': module_cls = LightCbamModule else: assert False, "Invalid attn module (%s)" % attn_type elif isinstance(attn_type, bool): if attn_type: module_cls = SEModule else: module_cls = attn_type return module_cls def create_attn(attn_type, channels, **kwargs): module_cls = get_attn(attn_type) if module_cls is not None: # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels return module_cls(channels, **kwargs) return None