""" Attention Factory

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from functools import partial

from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule


def get_attn(attn_type):
    """Resolve an attention specifier to an attention module class / factory.

    Args:
        attn_type: What to resolve. Handled, in order, as:
            * a ``torch.nn.Module`` instance -- returned unchanged,
            * any falsy value (``None``, ``False``, ``''``, ...) -- resolves to ``None``,
            * a string -- looked up case-insensitively among the names below,
            * ``True`` -- resolves to ``SEModule``,
            * any other truthy object -- returned unchanged.

    Returns:
        The matching module class, a ``functools.partial`` of one, the passed-through
        object, or ``None`` when no attention is requested.

    Raises:
        AssertionError: If ``attn_type`` is a non-empty string that is not a known name.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if not attn_type:
        return None
    if not isinstance(attn_type, str):
        # A bare `True` means "use the default attention" (SE); anything else is handed back as given.
        return SEModule if isinstance(attn_type, bool) else attn_type

    name = attn_type.lower()

    # Lightweight attention modules (channel and/or coarse spatial).
    # Typically added to existing network architecture blocks in addition to existing convolutions.
    if name == 'se':
        return SEModule
    if name == 'ese':
        return EffectiveSEModule
    if name == 'eca':
        return EcaModule
    if name == 'ecam':
        return partial(EcaModule, use_mlp=True)
    if name == 'ceca':
        return CecaModule
    if name == 'ge':
        return GatherExcite
    if name == 'gc':
        return GlobalContext
    if name == 'gca':
        return partial(GlobalContext, fuse_add=True, fuse_scale=False)
    if name == 'cbam':
        return CbamModule
    if name == 'lcbam':
        return LightCbamModule

    # Attention / attention-like modules w/ significant params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    if name == 'sk':
        return SelectiveKernel
    if name == 'splat':
        return SplitAttn

    # Self-attention / attention-like modules w/ significant compute and/or params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    if name == 'lambda':
        return LambdaLayer
    if name == 'bottleneck':
        return BottleneckAttn
    if name == 'halo':
        return HaloAttn
    if name == 'nl':
        return NonLocalAttn
    if name == 'bat':
        return BatNonLocalAttn

    # Raised explicitly rather than via `assert False`, which is stripped under `python -O`
    # and would let an unknown name silently resolve to "no attention".
    raise AssertionError("Invalid attn module (%s)" % name)


def create_attn(attn_type, channels, **kwargs):
    """Build an attention layer for the given number of input channels.

    Args:
        attn_type: Attention specifier, resolved via ``get_attn()``.
        channels: Number of input channels; passed as the first positional argument
            to the resolved attention class / factory.
        **kwargs: Extra keyword arguments forwarded to the attention class / factory.

    Returns:
        The constructed attention layer, the module itself when ``attn_type`` is an
        already-built ``torch.nn.Module`` (``channels`` and ``kwargs`` are then unused),
        or ``None`` when no attention is requested.
    """
    module_cls = get_attn(attn_type)
    if module_cls is None:
        return None
    if isinstance(module_cls, torch.nn.Module):
        # get_attn() hands back module instances untouched. Calling one here would run its
        # forward() on the channel count rather than construct a layer, so return it as-is.
        return module_cls
    # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
    return module_cls(channels, **kwargs)