Add ResNet-50 w/ GN (resnet50_gn) and SEBotNet-33-TS (sebotnet33ts_256) model defs and weights. Update halonet50ts weights w/ slightly better variant in1k val, more robust to test sets.

pull/989/head
Ross Wightman 3 years ago
parent 9b3519545d
commit c976a410d9

@ -36,6 +36,9 @@ default_cfgs = {
'botnet26t_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'sebotnet33ts_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'botnet50ts_256': _cfg(
url='',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
@ -51,7 +54,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'halonet50ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h_256-c6d7ff15.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'eca_halonext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
@ -97,6 +100,22 @@ model_cfgs = dict(
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
),
sebotnet33ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
act_layer='silu',
num_features=1280,
attn_layer='se',
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
),
botnet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
@ -322,6 +341,13 @@ def botnet26t_256(pretrained=False, **kwargs):
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@register_model
def sebotnet33ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
"""
return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
@register_model
def botnet50ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone, silu act.

@ -6,7 +6,7 @@ import torch.nn.functional as F
class GroupNorm(nn.GroupNorm):
def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)

@ -15,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier
from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -89,6 +89,11 @@ default_cfgs = {
interpolation='bicubic'),
'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
# ResNets w/ alternative norm layers
'resnet50_gn': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth',
crop_pct=0.94, interpolation='bicubic'),
# ResNeXt
'resnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth',
@ -881,6 +886,14 @@ def wide_resnet101_2(pretrained=False, **kwargs):
return _create_resnet('wide_resnet101_2', pretrained, **model_args)
@register_model
def resnet50_gn(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model w/ GroupNorm
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args)
@register_model
def resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.

Loading…
Cancel
Save