You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/create_attn.py

35 lines
1.1 KiB

""" Select AttentionFactory Method
Hacked together by Ross Wightman
"""
import torch
from .se import SqueezeExcite, SqueezeExciteV2
from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
from .cbam import ConvBlockAttn, LightConvBlockAttn
def create_attn(attn_type, channels, **kwargs):
    """Build an attention module selected by name or class.

    Args:
        attn_type: One of the following.
            * ``None``: no attention module is created.
            * A string naming a built-in module (case-insensitive):
              ``'se'``, ``'sev2'``, ``'eca'``, ``'ceca'``, ``'cbam'`` or ``'lcbam'``.
            * Any other object: used directly as the module class or factory.
        channels: Passed as the first positional argument to the module constructor.
        **kwargs: Forwarded unchanged to the module constructor.

    Returns:
        The constructed module, or ``None`` when ``attn_type`` is ``None``.

    Raises:
        AssertionError: If ``attn_type`` is a string that does not name a
            known attention module.
    """
    if attn_type is None:
        return None

    if isinstance(attn_type, str):
        attn_type = attn_type.lower()
        if attn_type == 'se':
            module_cls = SqueezeExcite
        elif attn_type == 'sev2':
            module_cls = SqueezeExciteV2
        elif attn_type == 'eca':
            module_cls = EfficientChannelAttn
        elif attn_type == 'ceca':
            module_cls = CircularEfficientChannelAttn
        elif attn_type == 'cbam':
            module_cls = ConvBlockAttn
        elif attn_type == 'lcbam':
            module_cls = LightConvBlockAttn
        else:
            # Raise explicitly rather than `assert False`: a bare assert is
            # stripped under `python -O`, which previously let an unknown
            # name silently return None. AssertionError is kept so existing
            # callers that catch it keep working.
            raise AssertionError("Invalid attn module (%s)" % attn_type)
    else:
        # Non-string values are treated as a user-supplied module class/factory.
        module_cls = attn_type

    return module_cls(channels, **kwargs)