You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/create_self_attn.py

23 lines
638 B

from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention
def get_self_attn(attn_type):
    """Look up a self-attention layer class by its short name.

    Args:
        attn_type: One of ``'bottleneck'``, ``'halo'``, ``'lambda'`` or ``'swin'``.

    Returns:
        The layer class itself (not an instance): ``BottleneckAttn``,
        ``HaloAttn``, ``LambdaLayer`` or ``WindowAttention`` respectively.

    Raises:
        ValueError: If ``attn_type`` is not one of the names above.
    """
    if attn_type == 'bottleneck':
        return BottleneckAttn
    elif attn_type == 'halo':
        return HaloAttn
    elif attn_type == 'lambda':
        return LambdaLayer
    elif attn_type == 'swin':
        return WindowAttention
    # Raise explicitly rather than `assert False`: asserts are removed under
    # `python -O`, which made this function silently return None and fail
    # later in the caller with an unrelated "NoneType is not callable" error.
    raise ValueError(f"Unknown attn type ({attn_type})")
def create_self_attn(attn_type, dim, stride=1, **kwargs):
    """Build a self-attention layer instance selected by name.

    The class is resolved via ``get_self_attn(attn_type)``. It is then
    constructed with ``dim`` as the first positional argument, ``stride`` as a
    keyword argument, and every extra keyword argument forwarded unchanged.

    Args:
        attn_type: Short layer name understood by ``get_self_attn``.
        dim: First positional constructor argument (presumably the input
            channel count; confirm against the individual layer classes).
        stride: Passed through as ``stride=``; defaults to 1.
        **kwargs: Any additional constructor keyword arguments.

    Returns:
        The newly constructed layer instance.
    """
    return get_self_attn(attn_type)(dim, stride=stride, **kwargs)