""" Select AttentionFactory Method

Hacked together by Ross Wightman
"""
import torch

from .se import SqueezeExcite, SqueezeExciteV2
from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
from .cbam import ConvBlockAttn, LightConvBlockAttn


def create_attn(attn_type, channels, **kwargs):
    """Build an attention module from a type name or a callable.

    Args:
        attn_type: Selects what to build. One of:
            * ``None`` - nothing is built.
            * a string (matched case-insensitively): ``'se'``, ``'sev2'``,
              ``'eca'``, ``'ceca'``, ``'cbam'`` or ``'lcbam'``.
            * any other object - treated as the module class / factory
              itself and called directly.
        channels: Passed as the first positional argument to the selected
            class / factory.
        **kwargs: Forwarded unchanged to the selected class / factory.

    Returns:
        The result of ``module_cls(channels, **kwargs)``, or ``None`` when
        ``attn_type`` is ``None``.

    Raises:
        AssertionError: If ``attn_type`` is a string that does not name a
            known attention module.
    """
    if attn_type is None:
        return None

    # Non-string selectors are used as the constructor directly.
    if not isinstance(attn_type, str):
        return attn_type(channels, **kwargs)

    attn_name = attn_type.lower()
    named_modules = {
        'se': SqueezeExcite,
        'sev2': SqueezeExciteV2,
        'eca': EfficientChannelAttn,
        'ceca': CircularEfficientChannelAttn,
        'cbam': ConvBlockAttn,
        'lcbam': LightConvBlockAttn,
    }
    module_cls = named_modules.get(attn_name)
    if module_cls is None:
        # Raised explicitly (rather than via an `assert` statement) so the
        # check is not stripped under `python -O`, where an unknown name
        # would otherwise fall through and silently return None. The
        # exception type and message are kept as before for callers.
        raise AssertionError("Invalid attn module (%s)" % attn_name)
    return module_cls(channels, **kwargs)