""" Attention Factory
|
||
5 years ago
|
|
||
4 years ago
|
Hacked together by / Copyright 2021 Ross Wightman
|
||
5 years ago
|
"""
|
||
|
import torch
from functools import partial

from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule

def get_attn(attn_type):
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    module_cls = None
    if attn_type:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
            # Lightweight attention modules (channel and/or coarse spatial).
            # Typically added to existing network architecture blocks in addition to existing convolutions.
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
            elif attn_type == 'ecam':
                module_cls = partial(EcaModule, use_mlp=True)
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'ge':
                module_cls = GatherExcite
            elif attn_type == 'gc':
                module_cls = GlobalContext
            elif attn_type == 'gca':
                module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
            elif attn_type == 'cbam':
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule

            # Attention / attention-like modules w/ significant params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'sk':
                module_cls = SelectiveKernel
            elif attn_type == 'splat':
                module_cls = SplitAttn

            # Self-attention / attention-like modules w/ significant compute and/or params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'lambda':
                return LambdaLayer
            elif attn_type == 'bottleneck':
                return BottleneckAttn
            elif attn_type == 'halo':
                return HaloAttn
            elif attn_type == 'nl':
                module_cls = NonLocalAttn
            elif attn_type == 'bat':
                module_cls = BatNonLocalAttn

            # Woops!
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
            if attn_type:
                module_cls = SEModule
        else:
            module_cls = attn_type
    return module_cls
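
# Note (added commentary, not in the original source): get_attn() only resolves the
# specification to a class or partial, it does not instantiate anything except for
# passing an already-constructed nn.Module straight through. Strings map to the classes
# imported above, a bool True falls back to SEModule, and any other non-string callable
# is returned unchanged for create_attn() below to instantiate.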


def create_attn(attn_type, channels, **kwargs):
    module_cls = get_attn(attn_type)
    if module_cls is not None:
        # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
        return module_cls(channels, **kwargs)
    return None
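

# Illustrative usage sketch (added for documentation, not part of the original timm file).
# create_attn() is the factory entry point: it resolves `attn_type` via get_attn() and
# instantiates the result with `channels` as the first positional argument, forwarding any
# extra kwargs to the module constructor. This assumes the module is run as part of its
# parent package (the relative imports above require it).
if __name__ == '__main__':
    # Resolve without instantiating: strings map to classes, bool True defaults to SE.
    assert get_attn('eca') is EcaModule
    assert get_attn(True) is SEModule
    assert get_attn(None) is None

    # Instantiate a squeeze-and-excite block for a 64-channel feature map and run it.
    attn = create_attn('se', 64)
    x = torch.randn(2, 64, 8, 8)
    y = attn(x)
    assert y.shape == x.shape  # channel attention modules preserve the input shape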