You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/create_attn.py

53 lines
1.7 KiB

""" Attention Factory: select and create attention modules by type
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .squeeze_excite import SEModule, EffectiveSEModule
def get_attn(attn_type):
    """Resolve an attention specifier to an attention module class.

    Args:
        attn_type: What kind of attention is wanted. Accepted forms:
            * a ``torch.nn.Module`` instance -- returned unchanged;
            * ``None`` -- no attention, ``None`` is returned;
            * a ``str`` (case-insensitive), one of ``'se'``, ``'ese'``,
              ``'eca'``, ``'ceca'``, ``'ge'``, ``'gc'``, ``'cbam'``,
              ``'lcbam'``;
            * a ``bool`` -- ``True`` selects ``SEModule``, ``False``
              yields ``None``;
            * any other object -- returned unchanged (presumably a module
              class or factory supplied by the caller).

    Returns:
        The selected attention class, the object passed through unchanged,
        or ``None`` when no attention is requested.

    Raises:
        AssertionError: If ``attn_type`` is a string that does not name a
            known attention module.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if attn_type is None:
        return None
    if isinstance(attn_type, bool):
        # True is shorthand for the default squeeze-excite module.
        return SEModule if attn_type else None
    if not isinstance(attn_type, str):
        # Not a name or flag: hand the object back untouched.
        return attn_type

    attn_type = attn_type.lower()
    if attn_type == 'se':
        return SEModule
    if attn_type == 'ese':
        return EffectiveSEModule
    if attn_type == 'eca':
        return EcaModule
    if attn_type == 'ceca':
        return CecaModule
    if attn_type == 'ge':
        return GatherExcite
    if attn_type == 'gc':
        return GlobalContext
    if attn_type == 'cbam':
        return CbamModule
    if attn_type == 'lcbam':
        return LightCbamModule
    # Raise explicitly rather than `assert False`: assert statements are
    # removed under `python -O`, which would make an unknown name silently
    # resolve to None. The exception type is kept for existing callers.
    raise AssertionError("Invalid attn module (%s)" % attn_type)
def create_attn(attn_type, channels, **kwargs):
    """Instantiate an attention layer for the given specifier.

    Args:
        attn_type: Attention specifier, resolved through ``get_attn``.
        channels: Passed as the first positional argument to whatever
            ``get_attn`` returns.
        **kwargs: Extra keyword arguments forwarded to that same call.

    Returns:
        The result of calling the resolved object with ``channels`` and
        ``kwargs``, or ``None`` when ``get_attn`` resolves to ``None``.
    """
    attn_factory = get_attn(attn_type)
    if attn_factory is None:
        return None
    # NOTE: every attention layer is expected to take the number of input
    # channels as its first (positional) argument.
    return attn_factory(channels, **kwargs)