diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index d34977b6..e458ca6f 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -45,6 +45,7 @@ default_cfgs = { 'halonet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'sehalonet33ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'eca_halonext26ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth', @@ -131,6 +132,22 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16) ), + sehalonet33ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + act_layer='silu', + num_features=1280, + attn_layer='se', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=3) + ), halonet50ts=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), @@ -227,6 +244,13 @@ def halonet26t(pretrained=False, **kwargs): return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) +@register_model +def sehalonet33ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages + """ + return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs) + + @register_model def halonet50ts(pretrained=False, **kwargs): """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 66baa37a..dad42f38 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -50,7 +50,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', interpolation='bicubic', first_conv='conv1.0'), 'resnet26t': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',