|
|
|
""" Select Attention Factory Method
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
import torch
|
# Monster commit, activation refactor, VoVNet, norm_act improvements, more
# * refactor activations into basic PyTorch, jit scripted, and memory efficient custom auto
# * implement hard-mish, better grad for hard-swish
# * add initial VovNet V1/V2 impl, fix #151
# * VovNet and DenseNet first models to use NormAct layers (support BatchNormAct2d, EvoNorm, InplaceIABN)
# * Wrap IABN for any models that use it
# * make more models torchscript compatible (DPN, PNasNet, Res2Net, SelecSLS) and add tests
# (5 years ago)
|
|
|
from .se import SEModule, EffectiveSEModule
|
|
|
|
from .eca import EcaModule, CecaModule
|
|
|
|
from .cbam import CbamModule, LightCbamModule
|
|
|
|
|
|
|
|
|
|
|
|
def get_attn(attn_type):
    """Resolve an attention module class from a name, a flag, or a class.

    Args:
        attn_type: One of
            * ``None`` -- no attention; ``None`` is returned.
            * ``bool`` -- ``True`` selects ``SEModule``; ``False`` returns ``None``.
            * ``str`` -- case-insensitive module name: ``'se'``, ``'ese'``,
              ``'eca'``, ``'ceca'``, ``'cbam'`` or ``'lcbam'``.
            * anything else -- returned unchanged (presumably already an
              attention module class or factory; not validated here).

    Returns:
        The selected module class, or ``None`` when no attention is requested.

    Raises:
        AssertionError: If ``attn_type`` is a string naming no known module.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped when Python runs with ``-O``.
    """
    if attn_type is None:
        return None

    if isinstance(attn_type, bool):
        return SEModule if attn_type else None

    if not isinstance(attn_type, str):
        # Not a name or flag: pass through as-is.
        return attn_type

    name = attn_type.lower()
    if name == 'se':
        return SEModule
    if name == 'ese':
        return EffectiveSEModule
    if name == 'eca':
        return EcaModule
    if name == 'ceca':
        return CecaModule
    if name == 'cbam':
        return CbamModule
    if name == 'lcbam':
        return LightCbamModule

    # An `assert False` here would vanish under `python -O` and silently
    # return None (disabling attention); raise explicitly, keeping the same
    # exception type and message for existing callers.
    raise AssertionError("Invalid attn module (%s)" % name)
|
|
|
|
|
|
|
|
|
|
|
|
def create_attn(attn_type, channels, **kwargs):
    """Build the attention layer selected by ``attn_type``.

    Args:
        attn_type: Any selector accepted by ``get_attn`` (``None``, a bool,
            a module name string, or a module class).
        channels: Channel count, passed as the first positional argument to
            the selected module class.
        **kwargs: Extra keyword arguments forwarded to the module constructor.

    Returns:
        The constructed attention module, or ``None`` when ``get_attn``
        resolves to no module.
    """
    attn_cls = get_attn(attn_type)
    if attn_cls is None:
        return None
    # NOTE: every attention layer is expected to accept the input channel
    # count as its first positional argument.
    return attn_cls(channels, **kwargs)
|