From e2b8d44ff0220bf3ccb1b11c62f222975e80606f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 16:29:33 -0700 Subject: [PATCH] Halo, bottleneck attn, lambda layer additions and cleanup along w/ experimental model defs * align interfaces of halo, bottleneck attn and lambda layer * add qk_ratio to all of above, control q/k dim relative to output dim * add experimental haloregnetz, and trionet (lambda + halo + bottle) models --- timm/models/byoanet.py | 61 +++++++++++++++++++++++++++ timm/models/byobnet.py | 10 ++--- timm/models/layers/bottleneck_attn.py | 60 ++++++++++++++++++-------- timm/models/layers/halo_attn.py | 59 ++++++++++++++++++++------ timm/models/layers/lambda_layer.py | 54 ++++++++++++++++-------- 5 files changed, 191 insertions(+), 53 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 61f94490..8c816f6e 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -66,6 +66,13 @@ default_cfgs = { 'lambda_resnet26rpt_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'haloregnetz_b': _cfg( + url='', + input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), + 'trionet50ts_256': _cfg( + url='', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -232,6 +239,46 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=None) ), + + # experimental + haloregnetz_b=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3), + ), + stem_chs=32, + stem_pool='', + downsample='', + num_features=1536, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33) + ), + + # experimental + trionet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks( + types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25, + self_attn_layer='lambda', self_attn_kwargs=dict(r=13)), + interleave_blocks( + types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25, + self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)), + interleave_blocks( + types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25, + self_attn_layer='bottleneck', self_attn_kwargs=dict()), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + act_layer='silu', + ), ) @@ -327,3 +374,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) + + +@register_model +def haloregnetz_b(pretrained=False, **kwargs): + """ Halo + RegNetZ + """ + return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs) + + +@register_model +def trionet50ts_256(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet50-t backbone, silu act. 
Halo attention in final two stages + """ + return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 515f2073..4ac6ece3 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module): self.self_attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.conv2_kxk(x) x = self.self_attn(x) x = self.post_attn(x) x = self.conv3_1x1(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x - + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) _block_registry = dict( basic=BasicBlock, diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 61859f9c..15df62ae 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -66,10 +66,10 @@ class PosEmbedRel(nn.Module): self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) def forward(self, q): - B, num_heads, HW, _ = q.shape + B, HW, _ = q.shape # relative logits in width dimension. - q = q.reshape(B * num_heads, self.height, self.width, -1) + q = q.reshape(B, self.height, self.width, -1) rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) # relative logits in height dimension. @@ -77,35 +77,56 @@ class PosEmbedRel(nn.Module): rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) rel_logits = rel_logits_h + rel_logits_w - rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + rel_logits = rel_logits.reshape(B, HW, HW) return rel_logits class BottleneckAttn(nn.Module): """ Bottleneck Attention Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections """ - def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False): super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim assert dim_out % num_heads == 0 self.num_heads = num_heads - self.dim_out = dim_out - self.dim_head = dim_out // num_heads - self.scale = self.dim_head ** -0.5 + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 - self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) # NOTE I'm only supporting relative pos embedding for now - self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in trunc_normal_(self.pos_embed.height_rel, std=self.scale) trunc_normal_(self.pos_embed.width_rel, std=self.scale) @@ -114,15 +135,20 @@ class BottleneckAttn(nn.Module): assert H == self.pos_embed.height assert W == self.pos_embed.width - x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W - x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) - q, k, v = torch.split(x, self.num_heads, dim=1) + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. 
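+        # Resulting shapes from the split + reshape below:
+        #   q -> (B * num_heads, H * W, dim_head_qk)
+        #   k -> (B * num_heads, dim_head_qk, H * W)
+        #   v -> (B * num_heads, H * W, dim_head_v)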
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) - attn = (q @ k.transpose(-1, -2)) * self.scale - attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = (q @ k) * self.scale + attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W attn = attn.softmax(dim=-1) - out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W out = self.pool(out) return out diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 034c66a8..05fb1f6a 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -22,6 +22,7 @@ import torch from torch import nn import torch.nn.functional as F +from .helpers import make_divisible from .weight_init import trunc_normal_ @@ -98,31 +99,62 @@ class HaloAttn(nn.Module): Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` - https://arxiv.org/abs/2103.12731 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda) + stride: output stride of the module, query downscaled if > 1 (default: 1). + num_heads: parallel attention heads (default: 8). + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + block_size (int): size of blocks. (default: 8) + halo_size (int): size of halo overlap. (default: 3) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool) : add bias to q, k, and v projections + avg_down (bool): use average pool downsample instead of strided query blocks + """ def __init__( - self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, + qk_ratio=1.0, qkv_bias=False, avg_down=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 - self.stride = stride + assert stride in (1, 2) self.num_heads = num_heads - self.dim_head_qk = dim_head or dim_out // num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads self.dim_head_v = dim_out // self.num_heads self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v - self.block_size = block_size + self.scale = self.dim_head_qk ** -0.5 + self.block_size = self.block_size_ds = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size - self.scale = self.dim_head_qk ** -0.5 + self.block_stride = 1 + use_avg_pool = False + if stride > 1: + use_avg_pool = avg_down or block_size % stride != 0 + self.block_stride = 1 if use_avg_pool else stride + self.block_size_ds = self.block_size // self.block_stride # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. I haven't wrapped my head around how that'd look. - self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.pos_embed = PosEmbedRel( - block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() self.reset_parameters() @@ -140,11 +172,12 @@ class HaloAttn(nn.Module): num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks - bs_stride = self.block_size // self.stride q = self.q(x) # unfold - q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) + q = q.reshape( + -1, self.dim_head_qk, + num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head @@ -163,9 +196,11 @@ class HaloAttn(nn.Module): out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks # fold - out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) - out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride) - # B, dim_out, H // stride, W // stride + out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view( + B, self.dim_out_v, H // self.block_stride, W // self.block_stride) + # B, dim_out, H // block_stride, W // block_stride + out = self.pool(out) return out diff --git 
a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index eeb77e45..e50b43c8 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -24,7 +24,7 @@ import torch from torch import nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -44,28 +44,46 @@ class LambdaLayer(nn.Module): - https://arxiv.org/abs/2102.08602 NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. + + The internal dimensions of the lambda module are controlled via the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query (q) and key (k) dimension are determined by + * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None + * q = num_heads * dim_head, k = dim_head + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W + stride (int): output stride of the module, avg pool used if stride == 2 + num_heads (int): parallel attention heads. + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections """ def __init__( - self, - dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, + qk_ratio=1.0, qkv_bias=False): super().__init__() - self.dim = dim - self.dim_out = dim_out or dim - self.dim_k = dim_head # query depth 'k' + dim_out = dim_out or dim + assert dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads self.num_heads = num_heads - assert self.dim_out % num_heads == 0, ' should be divided by num_heads' - self.dim_v = self.dim_out // num_heads # value depth 'v' + self.dim_v = dim_out // num_heads self.qkv = nn.Conv2d( dim, - num_heads * dim_head + dim_head + self.dim_v, + num_heads * self.dim_qk + self.dim_qk + self.dim_v, kernel_size=1, bias=qkv_bias) - self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) self.norm_v = nn.BatchNorm2d(self.dim_v) if r is not None: # local lambda convolution for pos - self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) self.pos_emb = None self.rel_pos_indices = None else: @@ -74,7 +92,7 @@ class LambdaLayer(nn.Module): feat_size = to_2tuple(feat_size) rel_size = [2 * s - 1 for s in feat_size] self.conv_lambda = None - self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() @@ -82,9 +100,9 @@ class LambdaLayer(nn.Module): self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in if self.conv_lambda is not None: - trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) if self.pos_emb is not None: trunc_normal_(self.pos_emb, std=.02) @@ -93,17 +111,17 @@ class LambdaLayer(nn.Module): M = H * W qkv = self.qkv(x) q, k, v = torch.split(qkv, [ - self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) - q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V - k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M content_lam = k @ v # B, K, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K - position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V else: # FIXME relative pos embedding path not fully verified pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
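
Editor's note: the snippet below is an illustrative sketch, not part of the patch. It shows how the shared dim_head / qk_ratio arguments documented in the docstrings above resolve to concrete q/k/v projection widths for BottleneckAttn and HaloAttn (LambdaLayer differs slightly: its k projection is a single dim_qk-wide head, not num_heads * dim_qk). The helper names qkv_dims and the local make_divisible are hypothetical; make_divisible is assumed to mirror the rounding behaviour of timm's helper of the same name.

def make_divisible(v, divisor=8, round_limit=0.9):
    # Assumed to match timm's make_divisible: round to nearest multiple of divisor,
    # bumping up if rounding would drop more than (1 - round_limit) of the value.
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:
        new_v += divisor
    return new_v


def qkv_dims(dim_out, num_heads=4, dim_head=None, qk_ratio=1.0):
    # q/k head width: an explicit dim_head wins, otherwise derived from dim_out * qk_ratio
    dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio) // num_heads
    dim_head_v = dim_out // num_heads  # the v projection always sets the output dim
    return num_heads * dim_head_qk, num_heads * dim_head_qk, num_heads * dim_head_v


print(qkv_dims(256))                 # (256, 256, 256): qk_ratio=1.0, q/k match the output dim
print(qkv_dims(256, qk_ratio=0.33))  # (88, 88, 256): q/k shrunk, rounded to a multiple of 8
print(qkv_dims(256, dim_head=16))    # (64, 64, 256): explicit dim_head overrides qk_ratio

With the patch applied, the experimental configs register like any other timm model, e.g. timm.create_model('haloregnetz_b') or timm.create_model('trionet50ts_256'); the empty url fields in their default_cfgs mean no pretrained weights are published for them yet.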