Improve performance of HaloAttn, change default dim calc. Some cleanup / fixes for byoanet. Rename resnet26ts to tfs to distinguish (extra fc).

pull/821/head
Ross Wightman 3 years ago
parent a8b65695f1
commit 8449ba210c

@ -52,13 +52,12 @@ model_cfgs = dict(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
@ -66,14 +65,13 @@ model_cfgs = dict(
botnet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=1, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='bottleneck',
@ -83,13 +81,12 @@ model_cfgs = dict(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
@ -107,7 +104,7 @@ model_cfgs = dict(
stem_chs=64,
stem_type='7x7',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
@ -115,59 +112,57 @@ model_cfgs = dict(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
),
halonet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
interleave_blocks(
types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2)
self_attn_kwargs=dict(block_size=8, halo_size=3)
),
eca_halonext26ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
),
lambda_resnet26t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='lambda',
self_attn_kwargs=dict(r=9)
),
@ -185,7 +180,7 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
@register_model
def botnet26t_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@ -193,7 +188,7 @@ def botnet26t_256(pretrained=False, **kwargs):
@register_model
def botnet50ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
""" Bottleneck Transformer w/ ResNet50-T backbone, silu act. Bottleneck attn in final two stages.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@ -201,7 +196,7 @@ def botnet50ts_256(pretrained=False, **kwargs):
@register_model
def eca_botnext26ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
""" Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@ -210,35 +205,34 @@ def eca_botnext26ts_256(pretrained=False, **kwargs):
@register_model
def halonet_h1(pretrained=False, **kwargs):
""" HaloNet-H1. Halo attention in all stages as per the paper.
This runs very slowly, param count lower than paper --> something is wrong.
NOTE: This runs very slowly!
"""
return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
@register_model
def halonet26t(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
""" HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
"""
return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
@register_model
def halonet50ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage
""" HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
"""
return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_halonext26ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
""" HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
"""
return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
""" Lambda-ResNet-26T. Lambda layers in last two stages.
"""
return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)

@ -107,13 +107,13 @@ default_cfgs = {
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
min_input_size=(3, 256, 256)),
'resnet26ts': _cfg(
'resnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet26ts': _cfg(
'gcresnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'seresnet26ts': _cfg(
'seresnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'eca_resnet26ts': _cfg(
'eca_resnet26tfs': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet50t': _cfg(
@ -174,13 +174,13 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
def interleave_blocks(
types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack
"""
assert len(types) == 2
if isinstance(every, int):
every = list(range(0 if first else every, d, every))
every = list(range(0 if first else every, d, every + 1))
if not every:
every = [d - 1]
set(every)
@ -300,21 +300,6 @@ model_cfgs = dict(
block_kwargs=dict(extra_conv=True),
),
# WARN: experimental, may vanish/change
geresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='ge',
attn_kwargs=dict(extent=8, extra_params=True),
),
# A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
# and a tiered stem w/ maxpool
resnext26ts=ByoModelCfg(
@ -327,7 +312,6 @@ model_cfgs = dict(
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
),
gcresnext26ts=ByoModelCfg(
@ -340,7 +324,6 @@ model_cfgs = dict(
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='gca',
),
@ -354,8 +337,7 @@ model_cfgs = dict(
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='relu',
act_layer='silu',
attn_layer='se',
),
eca_resnext26ts=ByoModelCfg(
@ -368,7 +350,6 @@ model_cfgs = dict(
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
),
@ -382,7 +363,6 @@ model_cfgs = dict(
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='bat',
attn_kwargs=dict(block_size=8)
@ -390,7 +370,7 @@ model_cfgs = dict(
# A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc
# and a tiered stem w/ no maxpool
resnet26ts=ByoModelCfg(
resnet26tfs=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -403,7 +383,7 @@ model_cfgs = dict(
num_features=0,
act_layer='silu',
),
gcresnet26ts=ByoModelCfg(
gcresnet26tfs=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -417,7 +397,7 @@ model_cfgs = dict(
act_layer='silu',
attn_layer='gca',
),
seresnet26ts=ByoModelCfg(
seresnet26tfs=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -431,7 +411,7 @@ model_cfgs = dict(
act_layer='silu',
attn_layer='se',
),
eca_resnet26ts=ByoModelCfg(
eca_resnet26tfs=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@ -455,7 +435,7 @@ model_cfgs = dict(
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
stem_pool='',
attn_layer='gca',
),
@ -614,31 +594,31 @@ def bat_resnext26ts(pretrained=False, **kwargs):
@register_model
def resnet26ts(pretrained=False, **kwargs):
def resnet26tfs(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('resnet26ts', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet26tfs', pretrained=pretrained, **kwargs)
@register_model
def gcresnet26ts(pretrained=False, **kwargs):
def gcresnet26tfs(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('gcresnet26ts', pretrained=pretrained, **kwargs)
return _create_byobnet('gcresnet26tfs', pretrained=pretrained, **kwargs)
@register_model
def seresnet26ts(pretrained=False, **kwargs):
def seresnet26tfs(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('seresnet26ts', pretrained=pretrained, **kwargs)
return _create_byobnet('seresnet26tfs', pretrained=pretrained, **kwargs)
@register_model
def eca_resnet26ts(pretrained=False, **kwargs):
def eca_resnet26tfs(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('eca_resnet26ts', pretrained=pretrained, **kwargs)
return _create_byobnet('eca_resnet26tfs', pretrained=pretrained, **kwargs)
@register_model
@ -1144,27 +1124,29 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
layer_fns = block_kwargs['layers']
# override attn layer / args with block local config
if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
attn_set = block_cfg.attn_layer is not None
if attn_set or block_cfg.attn_kwargs is not None:
# override attn layer config
if not block_cfg.attn_layer:
if attn_set and not block_cfg.attn_layer:
# empty string for attn_layer type will disable attn for this block
attn_layer = None
else:
attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None
attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
layer_fns = replace(layer_fns, attn=attn_layer)
# override self-attn layer / args with block local cfg
if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
self_attn_set = block_cfg.self_attn_layer is not None
if self_attn_set or block_cfg.self_attn_kwargs is not None:
# override attn layer config
if not block_cfg.self_attn_layer:
if self_attn_set and not block_cfg.self_attn_layer: # attn_layer == ''
# empty string for self_attn_layer type will disable attn for this block
self_attn_layer = None
else:
self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \
self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
if self_attn_layer is not None else None
layer_fns = replace(layer_fns, self_attn=self_attn_layer)

@ -103,19 +103,21 @@ class HaloAttn(nn.Module):
- https://arxiv.org/abs/2103.12731
"""
def __init__(
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False):
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
super().__init__()
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.stride = stride
self.num_heads = num_heads
self.dim_head = dim_head
self.dim_qk = num_heads * dim_head
self.dim_head = dim_head or dim // num_heads
self.dim_qk = num_heads * self.dim_head
self.dim_v = dim_out
self.block_size = block_size
self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head ** -0.5
# stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
self.stride_tricks = True
# FIXME not clear if this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
@ -139,28 +141,43 @@ class HaloAttn(nn.Module):
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks
bs_stride = self.block_size // self.stride
q = self.q(x)
q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
# q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
# B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x)
# FIXME I 'think' this unfold does what I want it to, but I should investigate
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
# generate overlapping windows using either stride tricks (as_strided) or unfold
if self.stride_tricks:
# this is much faster
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
kv = kv.as_strided((
B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
else:
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
kv = kv.reshape(
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
attn_out = F.fold(
attn_out.reshape(B, -1, num_blocks),
(H // self.stride, W // self.stride),
kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
# F.fold can be replaced by reshape + permute, slightly faster
# attn_out = F.fold(
# attn_out.reshape(B, -1, num_blocks),
# (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
# B, dim_out, H // stride, W // stride
return attn_out

Loading…
Cancel
Save