""" Attention Factory

Hacked together by / Copyright 2021 Ross Wightman
"""
from functools import partial

import torch

from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule

# Lower-case attention type name -> class / factory used to build that layer.
_ATTN_LAYERS = {
    # Lightweight attention modules (channel and/or coarse spatial).
    # Typically added to existing network architecture blocks in addition to existing convolutions.
    'se': SEModule,
    'ese': EffectiveSEModule,
    'eca': EcaModule,
    'ecam': partial(EcaModule, use_mlp=True),
    'ceca': CecaModule,
    'ge': GatherExcite,
    'gc': GlobalContext,
    'gca': partial(GlobalContext, fuse_add=True, fuse_scale=False),
    'cbam': CbamModule,
    'lcbam': LightCbamModule,

    # Attention / attention-like modules w/ significant params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    'sk': SelectiveKernel,
    'splat': SplitAttn,

    # Self-attention / attention-like modules w/ significant compute and/or params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    'lambda': LambdaLayer,
    'bottleneck': BottleneckAttn,
    'halo': HaloAttn,
    'nl': NonLocalAttn,
    'bat': BatNonLocalAttn,
}


def get_attn(attn_type):
    """Resolve an attention-type spec to the callable that builds the layer.

    Args:
        attn_type: One of
            * a ``torch.nn.Module`` instance -- returned unchanged;
            * a string name (case-insensitive) that is a key of ``_ATTN_LAYERS``,
              e.g. ``'se'``, ``'eca'``, ``'cbam'``, ``'halo'``;
            * ``True`` -- shorthand for ``SEModule``;
            * any other truthy value (e.g. a layer class or factory) -- returned unchanged;
            * a falsy value (``None``, ``False``, ``''``) -- no attention requested.

    Returns:
        The layer class / factory (or module instance), or ``None`` when no
        attention is requested.

    Raises:
        ValueError: if ``attn_type`` is a string that does not name a known layer.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if not attn_type:
        return None
    if isinstance(attn_type, str):
        attn_type = attn_type.lower()
        try:
            return _ATTN_LAYERS[attn_type]
        except KeyError:
            # Raise explicitly (not `assert False`) so the check still fires under `python -O`.
            raise ValueError("Invalid attn module (%s)" % attn_type) from None
    if isinstance(attn_type, bool):
        # Only True can reach here; falsy values returned None above.
        return SEModule
    return attn_type


def create_attn(attn_type, channels, **kwargs):
    """Build an attention layer for the given number of input channels.

    Args:
        attn_type: Attention spec accepted by ``get_attn``.
        channels: Input channels, passed as the first positional argument to the layer.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        The constructed layer, or ``None`` when ``attn_type`` requests no attention.

    Raises:
        ValueError: if ``attn_type`` is an unknown attention name (from ``get_attn``).
    """
    module_cls = get_attn(attn_type)
    if module_cls is None:
        return None
    # NOTE: it's expected the first (positional) argument of all attention layers is the
    # input channels
    # NOTE(review): when attn_type is an nn.Module *instance*, get_attn returns it unchanged,
    # so this line calls that instance with `channels` -- confirm this is intended.
    return module_cls(channels, **kwargs)