From e15c3886ba8303868d5f86a43b0e5c4837eb4df2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 29 Apr 2021 10:58:49 -0700
Subject: [PATCH] Default lambda r=7. Define '26t' stage 4/5 256x256 variants for all of bot/halo/lambda nets for experiment. Add resnet50t for exp. Fix a few comments.

---
 timm/models/byoanet.py             | 32 ++++++++++++++++++++++++++++----
 timm/models/layers/halo_attn.py    |  8 ++++----
 timm/models/layers/lambda_layer.py |  2 +-
 timm/models/resnet.py              | 12 ++++++++++++
 4 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index 7b7481dd..ca46043f 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -45,15 +45,16 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # GPU-Efficient (ResNet) weights
+    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256)),
     'botnet50t_224': _cfg(url='', fixed_input_size=True),
     'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),

     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(url=''),
+    'halonet26t': _cfg(url='', input_size=(3, 256, 256)),
     'halonet50t': _cfg(url=''),

-    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
+    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256)),
     'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
 }
@@ -92,6 +93,21 @@ def interleave_attn(

 model_cfgs = dict(
+    botnet26t=ByoaCfg(
+        blocks=(
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        num_features=0,
+        self_attn_layer='bottleneck',
+        self_attn_fixed_size=True,
+        self_attn_kwargs=dict()
+    ),
     botnet50t=ByoaCfg(
         blocks=(
             ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
@@ -161,7 +177,7 @@ model_cfgs = dict(
         blocks=(
             ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
-            ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
             ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
@@ -169,7 +185,7 @@ model_cfgs = dict(
         stem_pool='maxpool',
         num_features=0,
         self_attn_layer='halo',
-        self_attn_kwargs=dict(block_size=7, halo_size=2)
+        self_attn_kwargs=dict(block_size=8, halo_size=2)  # intended for 256x256 res
     ),
     halonet50t=ByoaCfg(
         blocks=(
@@ -370,6 +386,14 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
         **kwargs)


+@register_model
+def botnet26t_256(pretrained=False, **kwargs):
+    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def botnet50t_224(pretrained=False, **kwargs):
     """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
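As a quick usage sketch (not part of the patch itself): with this change applied, the new 256x256 bottleneck-attention variant should be creatable by name through the timm model registry. The snippet below is illustrative only; `timm.create_model` and the `botnet26t_256` name come from the diff above, while the batch size and eval-only forward are assumptions, and there are no pretrained weights yet (url='').

import torch
import timm

# botnet26t_256 is registered above with fixed_input_size=True and a
# (3, 256, 256) default input size, so feed it 256x256 images.
model = timm.create_model('botnet26t_256', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 1000]) with the default classifier head
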
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index bd5d1b45..8452aa94 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -115,7 +115,7 @@ class HaloAttn(nn.Module):
         self.win_size = block_size + halo_size * 2 # neighbourhood window size
         self.scale = self.dim_head ** -0.5

-        # FIXME not clear if this stride behaviour is what the paper intended, not really clear
+        # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
         # data in unfolded block form. I haven't wrapped my head around how that'd look.
         self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
@@ -139,10 +139,10 @@ class HaloAttn(nn.Module):

         kv = self.kv(x)
         # FIXME I 'think' this unfold does what I want it to, but I should investigate
-        k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
-        k = k.reshape(
+        kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+        kv = kv.reshape(
             B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
-        k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
+        k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)

         attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
         attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
index bdaebb5d..c89982af 100644
--- a/timm/models/layers/lambda_layer.py
+++ b/timm/models/layers/lambda_layer.py
@@ -34,7 +34,7 @@ class LambdaLayer(nn.Module):
     """
     def __init__(
             self,
-            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False):
+            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
         self.dim_out = dim_out or dim
         self.dim_k = dim_head # query depth 'k'
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 656e3a51..2b38b963 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -54,6 +54,9 @@ default_cfgs = {
     'resnet50d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
         interpolation='bicubic', first_conv='conv1.0'),
+    'resnet50t': _cfg(
+        url='',
+        interpolation='bicubic', first_conv='conv1.0'),
     'resnet101': _cfg(url='', interpolation='bicubic'),
     'resnet101d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
@@ -706,6 +709,15 @@ def resnet50d(pretrained=False, **kwargs):
     return _create_resnet('resnet50d', pretrained, **model_args)


+@register_model
+def resnet50t(pretrained=False, **kwargs):
+    """Constructs a ResNet-50-T model.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
+    return _create_resnet('resnet50t', pretrained, **model_args)
+
+
 @register_model
 def resnet101(pretrained=False, **kwargs):
     """Constructs a ResNet-101 model.
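For anyone puzzling over the FIXME next to the unfold in halo_attn.py, here is a minimal, self-contained sketch of what the renamed `kv` gather does. The tensor sizes and the dim_k/dim_v split values below are toy numbers picked for illustration, not HaloAttn's actual dimensions; only the `F.unfold` call and the final `torch.split` mirror the patched code.

import torch
import torch.nn.functional as F

B, C, H, W = 1, 16, 16, 16                 # toy feature map
block_size, halo_size = 8, 2               # as in the updated halonet26t kwargs
win_size = block_size + halo_size * 2      # 12: each query block plus its halo

kv = torch.randn(B, C, H, W)
num_blocks = (H // block_size) * (W // block_size)   # 4 blocks of 8x8

# One overlapping win_size x win_size neighbourhood is gathered per block.
kv = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
print(kv.shape)            # torch.Size([1, 2304, 4]) == (B, C * win_size**2, num_blocks)

# Separate key and value channels, analogous to the per-head split above.
dim_k, dim_v = 4, 12                       # toy depths, dim_k + dim_v == C
kv = kv.reshape(B, dim_k + dim_v, win_size * win_size, num_blocks)
k, v = torch.split(kv, [dim_k, dim_v], dim=1)
print(k.shape, v.shape)    # torch.Size([1, 4, 144, 4]) torch.Size([1, 12, 144, 4])

Each of the num_blocks columns holds the 12x12 key/value neighbourhood for one 8x8 query block; presumably this is also why block_size moves from 7 to 8 alongside the new 256x256 defaults, since 8 evenly divides the feature maps in the stages where halo attention is placed.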