Improve performance of HaloAttn, change default dim calc. Some cleanup / fixes for byoanet. Rename resnet26ts to tfs to distinguish (extra fc).

pull/821/head
Ross Wightman 3 years ago
parent a8b65695f1
commit 8449ba210c

@ -52,13 +52,12 @@ model_cfgs = dict(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
fixed_input_size=True, fixed_input_size=True,
self_attn_layer='bottleneck', self_attn_layer='bottleneck',
self_attn_kwargs=dict() self_attn_kwargs=dict()
@ -66,14 +65,13 @@ model_cfgs = dict(
botnet50ts=ByoModelCfg( botnet50ts=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=1, gs=0, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='', stem_pool='',
num_features=0,
fixed_input_size=True, fixed_input_size=True,
act_layer='silu', act_layer='silu',
self_attn_layer='bottleneck', self_attn_layer='bottleneck',
@ -83,13 +81,12 @@ model_cfgs = dict(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
fixed_input_size=True, fixed_input_size=True,
act_layer='silu', act_layer='silu',
attn_layer='eca', attn_layer='eca',
@ -107,7 +104,7 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='7x7', stem_type='7x7',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
self_attn_layer='halo', self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3), self_attn_kwargs=dict(block_size=8, halo_size=3),
), ),
@ -115,59 +112,57 @@ model_cfgs = dict(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
self_attn_layer='halo', self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
), ),
halonet50ts=ByoModelCfg( halonet50ts=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), interleave_blocks(
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
self_attn_layer='halo', self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) self_attn_kwargs=dict(block_size=8, halo_size=3)
), ),
eca_halonext26ts=ByoModelCfg( eca_halonext26ts=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
attn_layer='eca', attn_layer='eca',
self_attn_layer='halo', self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
), ),
lambda_resnet26t=ByoModelCfg( lambda_resnet26t=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
self_attn_layer='lambda', self_attn_layer='lambda',
self_attn_kwargs=dict(r=9) self_attn_kwargs=dict(r=9)
), ),
@ -185,7 +180,7 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
@register_model @register_model
def botnet26t_256(pretrained=False, **kwargs): def botnet26t_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages.
""" """
kwargs.setdefault('img_size', 256) kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@ -193,7 +188,7 @@ def botnet26t_256(pretrained=False, **kwargs):
@register_model @register_model
def botnet50ts_256(pretrained=False, **kwargs): def botnet50ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage. """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. Bottleneck attn in final two stages.
""" """
kwargs.setdefault('img_size', 256) kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@ -201,7 +196,7 @@ def botnet50ts_256(pretrained=False, **kwargs):
@register_model @register_model
def eca_botnext26ts_256(pretrained=False, **kwargs): def eca_botnext26ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. """ Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
""" """
kwargs.setdefault('img_size', 256) kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@ -210,35 +205,34 @@ def eca_botnext26ts_256(pretrained=False, **kwargs):
@register_model @register_model
def halonet_h1(pretrained=False, **kwargs): def halonet_h1(pretrained=False, **kwargs):
""" HaloNet-H1. Halo attention in all stages as per the paper. """ HaloNet-H1. Halo attention in all stages as per the paper.
NOTE: This runs very slowly!
This runs very slowly, param count lower than paper --> something is wrong.
""" """
return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs) return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
@register_model @register_model
def halonet26t(pretrained=False, **kwargs): def halonet26t(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
""" """
return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
@register_model @register_model
def halonet50ts(pretrained=False, **kwargs): def halonet50ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
""" """
return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
@register_model @register_model
def eca_halonext26ts(pretrained=False, **kwargs): def eca_halonext26ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
""" """
return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
@register_model @register_model
def lambda_resnet26t(pretrained=False, **kwargs): def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. """ Lambda-ResNet-26T. Lambda layers in last two stages.
""" """
return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)

@ -107,13 +107,13 @@ default_cfgs = {
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic', first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
min_input_size=(3, 256, 256)), min_input_size=(3, 256, 256)),
'resnet26ts': _cfg( 'resnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet26ts': _cfg( 'gcresnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'seresnet26ts': _cfg( 'seresnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'eca_resnet26ts': _cfg( 'eca_resnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet50t': _cfg( 'gcresnet50t': _cfg(
@ -174,13 +174,13 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
def interleave_blocks( def interleave_blocks(
types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
) -> Tuple[ByoBlockCfg]: ) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack """ interleave 2 block types in stack
""" """
assert len(types) == 2 assert len(types) == 2
if isinstance(every, int): if isinstance(every, int):
every = list(range(0 if first else every, d, every)) every = list(range(0 if first else every, d, every + 1))
if not every: if not every:
every = [d - 1] every = [d - 1]
set(every) set(every)
@ -300,21 +300,6 @@ model_cfgs = dict(
block_kwargs=dict(extra_conv=True), block_kwargs=dict(extra_conv=True),
), ),
# WARN: experimental, may vanish/change
geresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='ge',
attn_kwargs=dict(extent=8, extra_params=True),
),
# A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act, # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
# and a tiered stem w/ maxpool # and a tiered stem w/ maxpool
resnext26ts=ByoModelCfg( resnext26ts=ByoModelCfg(
@ -327,7 +312,6 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
), ),
gcresnext26ts=ByoModelCfg( gcresnext26ts=ByoModelCfg(
@ -340,7 +324,6 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
attn_layer='gca', attn_layer='gca',
), ),
@ -354,8 +337,7 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0, act_layer='silu',
act_layer='relu',
attn_layer='se', attn_layer='se',
), ),
eca_resnext26ts=ByoModelCfg( eca_resnext26ts=ByoModelCfg(
@ -368,7 +350,6 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
attn_layer='eca', attn_layer='eca',
), ),
@ -382,7 +363,6 @@ model_cfgs = dict(
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool='maxpool', stem_pool='maxpool',
num_features=0,
act_layer='silu', act_layer='silu',
attn_layer='bat', attn_layer='bat',
attn_kwargs=dict(block_size=8) attn_kwargs=dict(block_size=8)
@ -390,7 +370,7 @@ model_cfgs = dict(
# A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc # A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc
# and a tiered stem w/ no maxpool # and a tiered stem w/ no maxpool
resnet26ts=ByoModelCfg( resnet26tfs=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -403,7 +383,7 @@ model_cfgs = dict(
num_features=0, num_features=0,
act_layer='silu', act_layer='silu',
), ),
gcresnet26ts=ByoModelCfg( gcresnet26tfs=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -417,7 +397,7 @@ model_cfgs = dict(
act_layer='silu', act_layer='silu',
attn_layer='gca', attn_layer='gca',
), ),
seresnet26ts=ByoModelCfg( seresnet26tfs=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -431,7 +411,7 @@ model_cfgs = dict(
act_layer='silu', act_layer='silu',
attn_layer='se', attn_layer='se',
), ),
eca_resnet26ts=ByoModelCfg( eca_resnet26tfs=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -455,7 +435,7 @@ model_cfgs = dict(
), ),
stem_chs=64, stem_chs=64,
stem_type='tiered', stem_type='tiered',
stem_pool=None, stem_pool='',
attn_layer='gca', attn_layer='gca',
), ),
@ -614,31 +594,31 @@ def bat_resnext26ts(pretrained=False, **kwargs):
@register_model @register_model
def resnet26ts(pretrained=False, **kwargs): def resnet26tfs(pretrained=False, **kwargs):
""" """
""" """
return _create_byobnet('resnet26ts', pretrained=pretrained, **kwargs) return _create_byobnet('resnet26tfs', pretrained=pretrained, **kwargs)
@register_model @register_model
def gcresnet26ts(pretrained=False, **kwargs): def gcresnet26tfs(pretrained=False, **kwargs):
""" """
""" """
return _create_byobnet('gcresnet26ts', pretrained=pretrained, **kwargs) return _create_byobnet('gcresnet26tfs', pretrained=pretrained, **kwargs)
@register_model @register_model
def seresnet26ts(pretrained=False, **kwargs): def seresnet26tfs(pretrained=False, **kwargs):
""" """
""" """
return _create_byobnet('seresnet26ts', pretrained=pretrained, **kwargs) return _create_byobnet('seresnet26tfs', pretrained=pretrained, **kwargs)
@register_model @register_model
def eca_resnet26ts(pretrained=False, **kwargs): def eca_resnet26tfs(pretrained=False, **kwargs):
""" """
""" """
return _create_byobnet('eca_resnet26ts', pretrained=pretrained, **kwargs) return _create_byobnet('eca_resnet26tfs', pretrained=pretrained, **kwargs)
@register_model @register_model
@ -1144,27 +1124,29 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
layer_fns = block_kwargs['layers'] layer_fns = block_kwargs['layers']
# override attn layer / args with block local config # override attn layer / args with block local config
if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None: attn_set = block_cfg.attn_layer is not None
if attn_set or block_cfg.attn_kwargs is not None:
# override attn layer config # override attn layer config
if not block_cfg.attn_layer: if attn_set and not block_cfg.attn_layer:
# empty string for attn_layer type will disable attn for this block # empty string for attn_layer type will disable attn for this block
attn_layer = None attn_layer = None
else: else:
attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs) attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
attn_layer = block_cfg.attn_layer or model_cfg.attn_layer attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
layer_fns = replace(layer_fns, attn=attn_layer) layer_fns = replace(layer_fns, attn=attn_layer)
# override self-attn layer / args with block local cfg # override self-attn layer / args with block local cfg
if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None: self_attn_set = block_cfg.self_attn_layer is not None
if self_attn_set or block_cfg.self_attn_kwargs is not None:
# override attn layer config # override attn layer config
if not block_cfg.self_attn_layer: if self_attn_set and not block_cfg.self_attn_layer: # attn_layer == ''
# empty string for self_attn_layer type will disable attn for this block # empty string for self_attn_layer type will disable attn for this block
self_attn_layer = None self_attn_layer = None
else: else:
self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \ self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
if self_attn_layer is not None else None if self_attn_layer is not None else None
layer_fns = replace(layer_fns, self_attn=self_attn_layer) layer_fns = replace(layer_fns, self_attn=self_attn_layer)

@ -103,19 +103,21 @@ class HaloAttn(nn.Module):
- https://arxiv.org/abs/2103.12731 - https://arxiv.org/abs/2103.12731
""" """
def __init__( def __init__(
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False): self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
super().__init__() super().__init__()
dim_out = dim_out or dim dim_out = dim_out or dim
assert dim_out % num_heads == 0 assert dim_out % num_heads == 0
self.stride = stride self.stride = stride
self.num_heads = num_heads self.num_heads = num_heads
self.dim_head = dim_head self.dim_head = dim_head or dim // num_heads
self.dim_qk = num_heads * dim_head self.dim_qk = num_heads * self.dim_head
self.dim_v = dim_out self.dim_v = dim_out
self.block_size = block_size self.block_size = block_size
self.halo_size = halo_size self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head ** -0.5 self.scale = self.dim_head ** -0.5
# stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
self.stride_tricks = True
# FIXME not clear if this stride behaviour is what the paper intended # FIXME not clear if this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
@ -139,28 +141,43 @@ class HaloAttn(nn.Module):
num_h_blocks = H // self.block_size num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks num_blocks = num_h_blocks * num_w_blocks
bs_stride = self.block_size // self.stride
q = self.q(x) q = self.q(x)
q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) # q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
# B, num_heads * dim_head * block_size ** 2, num_blocks # B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head # B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x) kv = self.kv(x)
# FIXME I 'think' this unfold does what I want it to, but I should investigate
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) # generate overlapping windows using either stride tricks (as_strided) or unfold
if self.stride_tricks:
# this is much faster
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
kv = kv.as_strided((
B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
else:
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
kv = kv.reshape( kv = kv.reshape(
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
attn_out = attn_logits.softmax(dim=-1) attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
attn_out = F.fold(
attn_out.reshape(B, -1, num_blocks), # F.fold can be replaced by reshape + permute, slightly faster
(H // self.stride, W // self.stride), # attn_out = F.fold(
kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) # attn_out.reshape(B, -1, num_blocks),
# (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
# B, dim_out, H // stride, W // stride # B, dim_out, H // stride, W // stride
return attn_out return attn_out

Loading…
Cancel
Save