You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
638 B
23 lines
638 B
from .bottleneck_attn import BottleneckAttn
|
|
from .halo_attn import HaloAttn
|
|
from .lambda_layer import LambdaLayer
|
|
from .swin_attn import WindowAttention
|
|
|
|
|
|
def get_self_attn(attn_type: str):
    """Return the self-attention layer registered under ``attn_type``.

    Args:
        attn_type: one of ``'bottleneck'``, ``'halo'``, ``'lambda'`` or
            ``'swin'``.

    Returns:
        The object imported at module level for that name:
        ``BottleneckAttn``, ``HaloAttn``, ``LambdaLayer`` or
        ``WindowAttention`` respectively.

    Raises:
        AssertionError: if ``attn_type`` is not one of the names above.
    """
    if attn_type == 'bottleneck':
        return BottleneckAttn
    elif attn_type == 'halo':
        return HaloAttn
    elif attn_type == 'lambda':
        return LambdaLayer
    elif attn_type == 'swin':
        return WindowAttention
    # Raise explicitly instead of ``assert False``: assert statements are
    # removed under ``python -O``, which would make this silently return
    # ``None``. AssertionError is kept so existing callers catching it still
    # behave the same.
    raise AssertionError(f"Unknown attn type ({attn_type})")
|
|
|
|
|
|
def create_self_attn(attn_type, dim, stride=1, **kwargs):
    """Construct a self-attention layer of the requested type.

    The layer is looked up with ``get_self_attn(attn_type)`` and then called
    with ``dim`` as its only positional argument, ``stride`` as a keyword
    argument, and every extra keyword argument forwarded unchanged.

    Args:
        attn_type: name accepted by ``get_self_attn``.
        dim: first positional argument handed to the layer constructor.
        stride: forwarded as ``stride=``; defaults to 1.
        **kwargs: further keyword arguments passed straight through.

    Returns:
        Whatever calling the selected layer returns.
    """
    return get_self_attn(attn_type)(dim, stride=stride, **kwargs)
|