From 8449ba210c6bde6d65a237eb96a81b2ca2e38de2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 26 Aug 2021 21:56:44 -0700 Subject: [PATCH] Improve performance of HaloAttn, change default dim calc. Some cleanup / fixes for byoanet. Rename resnet26ts to tfs to distinguish (extra fc). --- timm/models/byoanet.py | 56 +++++++++++------------- timm/models/byobnet.py | 76 +++++++++++++-------------------- timm/models/layers/halo_attn.py | 37 +++++++++++----- 3 files changed, 81 insertions(+), 88 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index b11e7d52..17e6c514 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -52,13 +52,12 @@ model_cfgs = dict( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, fixed_input_size=True, self_attn_layer='bottleneck', self_attn_kwargs=dict() @@ -66,14 +65,13 @@ model_cfgs = dict( botnet50ts=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), - ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=1, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='', - num_features=0, fixed_input_size=True, act_layer='silu', self_attn_layer='bottleneck', @@ -83,13 +81,12 @@ model_cfgs = dict( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25), ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, fixed_input_size=True, act_layer='silu', attn_layer='eca', @@ -107,7 +104,7 @@ model_cfgs = dict( stem_chs=64, stem_type='7x7', stem_pool='maxpool', - num_features=0, + self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3), ), @@ -115,59 +112,57 @@ model_cfgs = dict( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, self_attn_layer='halo', - self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16) ), halonet50ts=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - 
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + interleave_blocks( + types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25, + self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)), + interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', self_attn_layer='halo', - self_attn_kwargs=dict(block_size=8, halo_size=2) + self_attn_kwargs=dict(block_size=8, halo_size=3) ), eca_halonext26ts=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', attn_layer='eca', self_attn_layer='halo', - self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16) ), lambda_resnet26t=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25), ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, self_attn_layer='lambda', self_attn_kwargs=dict(r=9) ), @@ -185,7 +180,7 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): @register_model def botnet26t_256(pretrained=False, **kwargs): - """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages. """ kwargs.setdefault('img_size', 256) return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) @@ -193,7 +188,7 @@ def botnet26t_256(pretrained=False, **kwargs): @register_model def botnet50ts_256(pretrained=False, **kwargs): - """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage. + """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. Bottleneck attn in final two stages. """ kwargs.setdefault('img_size', 256) return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) @@ -201,7 +196,7 @@ def botnet50ts_256(pretrained=False, **kwargs): @register_model def eca_botnext26ts_256(pretrained=False, **kwargs): - """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages. """ kwargs.setdefault('img_size', 256) return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) @@ -210,35 +205,34 @@ def eca_botnext26ts_256(pretrained=False, **kwargs): @register_model def halonet_h1(pretrained=False, **kwargs): """ HaloNet-H1. Halo attention in all stages as per the paper. 
- - This runs very slowly, param count lower than paper --> something is wrong. + NOTE: This runs very slowly! """ return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs) @register_model def halonet26t(pretrained=False, **kwargs): - """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages """ return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) @register_model def halonet50ts(pretrained=False, **kwargs): - """ HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage + """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages """ return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) @register_model def eca_halonext26ts(pretrained=False, **kwargs): - """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages """ return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) @register_model def lambda_resnet26t(pretrained=False, **kwargs): - """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ Lambda-ResNet-26T. Lambda layers in last two stages. """ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index add07b2f..81ef836b 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -107,13 +107,13 @@ default_cfgs = { first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic', min_input_size=(3, 256, 256)), - 'resnet26ts': _cfg( + 'resnet26tfs': _cfg( first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'gcresnet26ts': _cfg( + 'gcresnet26tfs': _cfg( first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'seresnet26ts': _cfg( + 'seresnet26tfs': _cfg( first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'eca_resnet26ts': _cfg( + 'eca_resnet26tfs': _cfg( first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), 'gcresnet50t': _cfg( @@ -174,13 +174,13 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): def interleave_blocks( - types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs + types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs ) -> Tuple[ByoBlockCfg]: """ interleave 2 block types in stack """ assert len(types) == 2 if isinstance(every, int): - every = list(range(0 if first else every, d, every)) + every = list(range(0 if first else every, d, every + 1)) if not every: every = [d - 1] set(every) @@ -300,21 +300,6 @@ model_cfgs = dict( block_kwargs=dict(extra_conv=True), ), - # WARN: experimental, may vanish/change - geresnet50t=ByoModelCfg( - blocks=( - ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25), - ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25), - ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), - ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), - ), - stem_chs=64, - stem_type='tiered', - stem_pool=None, - attn_layer='ge', - attn_kwargs=dict(extent=8, extra_params=True), - ), - # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act, # and a tiered stem w/ maxpool resnext26ts=ByoModelCfg( @@ -327,7 +312,6 @@ model_cfgs = dict( stem_chs=64, 
stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', ), gcresnext26ts=ByoModelCfg( @@ -340,7 +324,6 @@ model_cfgs = dict( stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', attn_layer='gca', ), @@ -354,8 +337,7 @@ model_cfgs = dict( stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, - act_layer='relu', + act_layer='silu', attn_layer='se', ), eca_resnext26ts=ByoModelCfg( @@ -368,7 +350,6 @@ model_cfgs = dict( stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', attn_layer='eca', ), @@ -382,15 +363,14 @@ model_cfgs = dict( stem_chs=64, stem_type='tiered', stem_pool='maxpool', - num_features=0, act_layer='silu', attn_layer='bat', attn_kwargs=dict(block_size=8) ), - # A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc + # A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc # and a tiered stem w/ no maxpool - resnet26ts=ByoModelCfg( + resnet26tfs=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), @@ -403,7 +383,7 @@ model_cfgs = dict( num_features=0, act_layer='silu', ), - gcresnet26ts=ByoModelCfg( + gcresnet26tfs=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), @@ -417,7 +397,7 @@ model_cfgs = dict( act_layer='silu', attn_layer='gca', ), - seresnet26ts=ByoModelCfg( + seresnet26tfs=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), @@ -431,7 +411,7 @@ model_cfgs = dict( act_layer='silu', attn_layer='se', ), - eca_resnet26ts=ByoModelCfg( + eca_resnet26tfs=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25), @@ -455,7 +435,7 @@ model_cfgs = dict( ), stem_chs=64, stem_type='tiered', - stem_pool=None, + stem_pool='', attn_layer='gca', ), @@ -614,31 +594,31 @@ def bat_resnext26ts(pretrained=False, **kwargs): @register_model -def resnet26ts(pretrained=False, **kwargs): +def resnet26tfs(pretrained=False, **kwargs): """ """ - return _create_byobnet('resnet26ts', pretrained=pretrained, **kwargs) + return _create_byobnet('resnet26tfs', pretrained=pretrained, **kwargs) @register_model -def gcresnet26ts(pretrained=False, **kwargs): +def gcresnet26tfs(pretrained=False, **kwargs): """ """ - return _create_byobnet('gcresnet26ts', pretrained=pretrained, **kwargs) + return _create_byobnet('gcresnet26tfs', pretrained=pretrained, **kwargs) @register_model -def seresnet26ts(pretrained=False, **kwargs): +def seresnet26tfs(pretrained=False, **kwargs): """ """ - return _create_byobnet('seresnet26ts', pretrained=pretrained, **kwargs) + return _create_byobnet('seresnet26tfs', pretrained=pretrained, **kwargs) @register_model -def eca_resnet26ts(pretrained=False, **kwargs): +def eca_resnet26tfs(pretrained=False, **kwargs): """ """ - return _create_byobnet('eca_resnet26ts', pretrained=pretrained, **kwargs) + return _create_byobnet('eca_resnet26tfs', pretrained=pretrained, **kwargs) @register_model @@ -1144,27 +1124,29 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo layer_fns = block_kwargs['layers'] # override attn layer / args with block local config - if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not 
None: + attn_set = block_cfg.attn_layer is not None + if attn_set or block_cfg.attn_kwargs is not None: # override attn layer config - if not block_cfg.attn_layer: + if attn_set and not block_cfg.attn_layer: # empty string for attn_layer type will disable attn for this block attn_layer = None else: attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs) attn_layer = block_cfg.attn_layer or model_cfg.attn_layer - attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None + attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None layer_fns = replace(layer_fns, attn=attn_layer) # override self-attn layer / args with block local cfg - if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None: + self_attn_set = block_cfg.self_attn_layer is not None + if self_attn_set or block_cfg.self_attn_kwargs is not None: # override attn layer config - if not block_cfg.self_attn_layer: + if self_attn_set and not block_cfg.self_attn_layer: # attn_layer == '' # empty string for self_attn_layer type will disable attn for this block self_attn_layer = None else: self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer - self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \ + self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \ if self_attn_layer is not None else None layer_fns = replace(layer_fns, self_attn=self_attn_layer) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 87cae895..044c5dad 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -103,19 +103,21 @@ class HaloAttn(nn.Module): - https://arxiv.org/abs/2103.12731 """ def __init__( - self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False): + self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 self.stride = stride self.num_heads = num_heads - self.dim_head = dim_head - self.dim_qk = num_heads * dim_head + self.dim_head = dim_head or dim // num_heads + self.dim_qk = num_heads * self.dim_head self.dim_v = dim_out self.block_size = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size self.scale = self.dim_head ** -0.5 + # stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA) + self.stride_tricks = True # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving @@ -139,28 +141,43 @@ class HaloAttn(nn.Module): num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks + bs_stride = self.block_size // self.stride q = self.q(x) - q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap + q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) # B * num_heads, 
num_blocks, block_size ** 2, dim_head kv = self.kv(x) - # FIXME I 'think' this unfold does what I want it to, but I should investigate - kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + + # generate overlapping windows using either stride tricks (as_strided) or unfold + if self.stride_tricks: + # this is much faster + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) + # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 attn_out = attn_logits.softmax(dim=-1) attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks - attn_out = F.fold( - attn_out.reshape(B, -1, num_blocks), - (H // self.stride, W // self.stride), - kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + + # F.fold can be replaced by reshape + permute, slightly faster + # attn_out = F.fold( + # attn_out.reshape(B, -1, num_blocks), + # (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride) + attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) + attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride) # B, dim_out, H // stride, W // stride return attn_out
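
The main speed-up in halo_attn.py is the key/value gather: the overlapping win_size x win_size neighbourhoods are now produced with as_strided over a padded tensor instead of F.unfold (the unfold path is kept behind the hard-coded stride_tricks flag since, per the patch comment, neither approach works under XLA). Below is a minimal standalone sketch, under assumed shapes, that checks the stride arithmetic used in the patch against the original unfold; C stands in for dim_qk + dim_v and the sizes are arbitrary:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 32, 16, 16            # assumed sizes, C plays the role of dim_qk + dim_v
block_size, halo_size = 8, 3
win_size = block_size + 2 * halo_size
num_h_blocks, num_w_blocks = H // block_size, W // block_size

kv = torch.randn(B, C, H, W)

# stride-tricks path: pad once, then view every overlapping win_size x win_size neighbourhood
kv_pad = F.pad(kv, [halo_size] * 4).contiguous()
win = kv_pad.as_strided(
    (B, C, win_size, win_size, num_h_blocks, num_w_blocks),
    stride=(kv_pad.stride(0), kv_pad.stride(1), kv_pad.shape[-1], 1,
            block_size * kv_pad.shape[-1], block_size))

# reference path: the unfold call this replaces
ref = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
ref = ref.reshape(B, C, win_size, win_size, num_h_blocks * num_w_blocks)

assert torch.equal(win.reshape(B, C, win_size, win_size, -1), ref)

The as_strided call itself is a zero-copy view; the window duplication is deferred to the later reshape, which is why the patch comment flags this path as much faster than unfold on CPU / GPU.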
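
For the query path the blocks do not overlap (kernel size equals stride), so the old F.unfold / F.fold pair was only doing a block partition and its inverse; the patch replaces both with reshape + permute. A standalone sketch under assumed shapes showing the partition really is equivalent (C stands in for num_heads * dim_head, bs for block_size // stride):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 16, 16, 16            # assumed sizes
bs = 8                                # block_size // stride
num_h_blocks, num_w_blocks = H // bs, W // bs

q = torch.randn(B, C, H, W)

# unfold with kernel_size == stride is just a non-overlapping block partition
ref = F.unfold(q, kernel_size=bs, stride=bs)      # B, C * bs * bs, num_blocks

# the same partition via reshape + permute, without the im2col buffer
blk = q.reshape(B, C, num_h_blocks, bs, num_w_blocks, bs).permute(0, 1, 3, 5, 2, 4)
blk = blk.reshape(B, C * bs * bs, num_h_blocks * num_w_blocks)

assert torch.equal(ref, blk)

The output side is the same argument in reverse: the commented-out F.fold is replaced by a reshape back to (bs, bs, num_h_blocks, num_w_blocks) followed by a permute to B, dim_v, H // stride, W // stride.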
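
In byobnet.py's update_block_kwargs(), the per-block overrides previously built the attention layer with partial(get_attn(attn_layer), *attn_kwargs); unpacking a dict with a single star passes its keys as positional arguments, so block-local kwargs were silently misapplied. The two-star form forwards them as keyword arguments. A tiny illustration with a hypothetical constructor (make_attn is a stand-in for whatever get_attn() returns, not timm API):

from functools import partial

def make_attn(channels, rd_ratio=0.25, block_size=8):
    # hypothetical stand-in for an attention layer constructor
    return ('attn', channels, rd_ratio, block_size)

attn_kwargs = dict(block_size=16)

broken = partial(make_attn, *attn_kwargs)   # binds the string 'block_size' as the first positional arg
print(broken(64))   # ('attn', 'block_size', 64, 8) -- the key became channels, 64 landed in rd_ratio
fixed = partial(make_attn, **attn_kwargs)   # binds block_size=16 as a keyword arg
print(fixed(64))    # ('attn', 64, 0.25, 16)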
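
The interleave_blocks() signature change makes every optional (default 1), moves it after d, and switches the range step to every + 1, so an integer every now reads as one self-attn block after every `every` baseline blocks rather than requiring every=1 at each call site. A small sketch of just the index selection, copied out of the helper for illustration (interleave_indices is a hypothetical name):

from typing import List, Union

def interleave_indices(d: int, every: Union[int, List[int]] = 1, first: bool = False) -> List[int]:
    # which of the d blocks in a stage get the second (self-attn) block type
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every + 1))
        if not every:
            every = [d - 1]
    return every

print(interleave_indices(d=2))            # [1]       -> bottle, self_attn
print(interleave_indices(d=6))            # [1, 3, 5] -> alternating bottle / self_attn
print(interleave_indices(d=4, every=4))   # [3]       -> only the last block, as in the halonet50ts c=512 stage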