Add new weights for ecaresnet26t/50t/269d models. Remove the distinction between 't' and 'tn' (tiered models); 'tn' is now 't'. Add test-time img size spec to default cfg.

pull/413/head
Ross Wightman 3 years ago
parent 4203efa36d
commit 68a4144882

@@ -122,8 +122,6 @@ model_list = [
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep stem, and avg-pool in downsample layers.'),
_entry('seresnext26t_32x4d', 'SE-ResNeXt-26-T 32x4d', '1812.01187',
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'),
_entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187',
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'),
_entry('seresnext50_32x4d', 'SE-ResNeXt-50 32x4d', '1709.01507'),
_entry('skresnet18', 'SK-ResNet-18', '1903.06586'),

@@ -5,7 +5,7 @@ from .constants import *
_logger = logging.getLogger(__name__)
def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
@@ -25,8 +25,11 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
elif 'img_size' in args and args['img_size'] is not None:
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
else:
if use_test_size and 'test_input_size' in default_cfg:
input_size = default_cfg['test_input_size']
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
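The new use_test_size flag makes the resolver prefer a cfg's test_input_size over its train-time input_size. A usage sketch (the call pattern follows timm's documented resolve_data_config usage; it is not part of this diff):

```python
import timm
from timm.data import resolve_data_config

# With use_test_size=True, a model whose default cfg carries the new
# test_input_size key resolves to the larger test resolution.
model = timm.create_model('resnet200d', pretrained=True)

train_cfg = resolve_data_config({}, model=model)                     # (3, 256, 256)
test_cfg = resolve_data_config({}, model=model, use_test_size=True)  # (3, 320, 320)
print(train_cfg['input_size'], test_cfg['input_size'])
```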

@@ -36,14 +36,17 @@ class TestTimePoolHead(nn.Module):
return x.view(x.size(0), -1)
def apply_test_time_pool(model, config):
def apply_test_time_pool(model, config, use_test_size=True):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False
if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and
config['input_size'][-2] > model.default_cfg['input_size'][-2]):
if use_test_size and 'test_input_size' in model.default_cfg:
df_input_size = model.default_cfg['test_input_size']
else:
df_input_size = model.default_cfg['input_size']
if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
_logger.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
(str(config['input_size'][-2:]), str(df_input_size[-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True
return model, test_time_pool
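With df_input_size now taken from test_input_size when available, running at the advertised test resolution no longer trips the pooling head. A sketch (the import path is as used by timm's validate.py of this era; treat it as an assumption):

```python
import timm
from timm.data import resolve_data_config
from timm.models import apply_test_time_pool  # import path per timm's validate.py

model = timm.create_model('resnet152d', pretrained=True)
config = resolve_data_config({}, model=model, use_test_size=True)  # (3, 320, 320)

# Previously 320px would exceed the 256px train cfg and enable pooling;
# now it is compared against test_input_size (320), so pooling stays off.
model, enabled = apply_test_time_pool(model, config)
print(enabled)  # False
```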

@@ -7,7 +7,6 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
Copyright 2020 Ross Wightman
"""
import math
import copy
import torch
import torch.nn as nn
@@ -58,24 +57,18 @@ default_cfgs = {
'resnet101': _cfg(url='', interpolation='bicubic'),
'resnet101d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'resnet101d_320': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=1.0, test_input_size=(3, 320, 320)),
'resnet152': _cfg(url='', interpolation='bicubic'),
'resnet152d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'resnet152d_320': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=1.0, test_input_size=(3, 320, 320)),
'resnet200': _cfg(url='', interpolation='bicubic'),
'resnet200d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'resnet200d_320': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=1.0, test_input_size=(3, 320, 320)),
'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
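Net effect of dropping the `_320` duplicates: a single model entry now advertises both resolutions through its cfg. A small check, assuming the post-commit cfgs shown above:

```python
import timm

# 'resnet200d_320' is gone; the same weights live under 'resnet200d', with the
# 320px resolution recorded as test_input_size instead of a duplicate entry.
model = timm.create_model('resnet200d', pretrained=True)
cfg = model.default_cfg
assert cfg['input_size'] == (3, 256, 256)       # train-time size
assert cfg['test_input_size'] == (3, 320, 320)  # recommended eval size
assert cfg['crop_pct'] == 1.0
```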
@@ -146,7 +139,7 @@ default_cfgs = {
'seresnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth',
interpolation='bicubic'),
'seresnet50tn': _cfg(
'seresnet50t': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
@@ -158,10 +151,9 @@ default_cfgs = {
interpolation='bicubic'),
'seresnet152d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'seresnet152d_320': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=1.0, test_input_size=(3, 320, 320)
),
'seresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
@@ -171,18 +163,11 @@ default_cfgs = {
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
'seresnext26_32x4d': _cfg(
url='',
interpolation='bicubic'),
'seresnext26d_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'seresnext26t_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'seresnext26tn_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
interpolation='bicubic',
first_conv='conv1.0'),
@@ -201,8 +186,10 @@ default_cfgs = {
first_conv='conv1.0'),
# Efficient Channel Attention ResNets
'ecaresnet18': _cfg(),
'ecaresnet50': _cfg(),
'ecaresnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
@@ -214,10 +201,13 @@ default_cfgs = {
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic',
first_conv='conv1.0'),
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
@@ -226,17 +216,17 @@ default_cfgs = {
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(8, 8),
crop_pct=1.0, test_input_size=(3, 352, 352)),
# Efficient Channel Attention ResNeXts
'ecaresnext26tn_32x4d': _cfg(
'ecaresnext26t_32x4d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnext50_32x4d': _cfg(
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnext50t_32x4d': _cfg(
url='',
interpolation='bicubic'),
interpolation='bicubic', first_conv='conv1.0'),
# ResNets with anti-aliasing blur pool
'resnetblur18': _cfg(
@@ -529,8 +519,7 @@ class ResNet(nn.Module):
The type of stem:
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width//4 * 6, stem_width * 2
* 'deep_tiered_narrow' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
block_reduce_first: int, default 1
Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2
@@ -564,18 +553,17 @@ class ResNet(nn.Module):
deep_stem = 'deep' in stem_type
inplanes = stem_width * 2 if deep_stem else 64
if deep_stem:
stem_chs_1 = stem_chs_2 = stem_width
stem_chs = (stem_width, stem_width)
if 'tiered' in stem_type:
stem_chs_1 = 3 * (stem_width // 4)
stem_chs_2 = stem_width if 'narrow' in stem_type else 6 * (stem_width // 4)
stem_chs = (3 * (stem_width // 4), stem_width)
self.conv1 = nn.Sequential(*[
nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
norm_layer(stem_chs_1),
nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
norm_layer(stem_chs[0]),
act_layer(inplace=True),
nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
norm_layer(stem_chs_2),
nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
norm_layer(stem_chs[1]),
act_layer(inplace=True),
nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
else:
self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(inplanes)
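A quick check of the channel progressions implied by the unified stem_chs tuple (a standalone sketch that mirrors, rather than imports, the logic above):

```python
# Mirrors the stem channel selection in ResNet.__init__ above.
def stem_channels(stem_width=32, stem_type='deep'):
    inplanes = stem_width * 2
    stem_chs = (stem_width, stem_width)
    if 'tiered' in stem_type:
        stem_chs = (3 * (stem_width // 4), stem_width)
    return (*stem_chs, inplanes)

assert stem_channels(32, 'deep') == (32, 32, 64)
# The surviving 'deep_tiered' matches the old narrow variant: 24, 32, 64.
assert stem_channels(32, 'deep_tiered') == (24, 32, 64)
```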
@@ -732,14 +720,6 @@ def resnet101d(pretrained=False, **kwargs):
return _create_resnet('resnet101d', pretrained, **model_args)
@register_model
def resnet101d_320(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet101d_320', pretrained, **model_args)
@register_model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
@@ -757,15 +737,6 @@ def resnet152d(pretrained=False, **kwargs):
return _create_resnet('resnet152d', pretrained, **model_args)
@register_model
def resnet152d_320(pretrained=False, **kwargs):
"""Constructs a ResNet-152-D model.
"""
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet152d_320', pretrained, **model_args)
@register_model
def resnet200(pretrained=False, **kwargs):
"""Constructs a ResNet-200 model.
@@ -783,15 +754,6 @@ def resnet200d(pretrained=False, **kwargs):
return _create_resnet('resnet200d', pretrained, **model_args)
@register_model
def resnet200d_320(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model. NOTE: Duplicate of 200D above w/ diff default cfg for 320x320.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet200d_320', pretrained, **model_args)
@register_model
def tv_resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model with original Torchvision weights.
@@ -1075,6 +1037,18 @@ def ecaresnet18(pretrained=False, **kwargs):
return _create_resnet('ecaresnet18', pretrained, **model_args)
@register_model
def ecaresnet26t(pretrained=False, **kwargs):
"""Constructs an ECA-ResNeXt-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem and ECA attn.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet26t', pretrained, **model_args)
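Pulling the new ecaresnet26t weights together with the test-size plumbing added in this same commit (standard timm usage; feeding the resolved config into create_transform is the documented pattern, not something introduced here):

```python
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('ecaresnet26t', pretrained=True).eval()
# use_test_size selects the cfg's test_input_size (3, 320, 320) over the
# 256x256 train-time size; crop_pct=0.95 comes from the same cfg.
config = resolve_data_config({}, model=model, use_test_size=True)
transform = create_transform(**config)
```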
@register_model
def ecaresnet50(pretrained=False, **kwargs):
"""Constructs an ECA-ResNet-50 model.
@@ -1104,6 +1078,17 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs):
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
@register_model
def ecaresnet50t(pretrained=False, **kwargs):
"""Constructs an ECA-ResNet-50-T model.
Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50t', pretrained, **model_args)
@register_model
def ecaresnetlight(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D light model with eca.
@@ -1156,16 +1141,27 @@ def ecaresnet269d(pretrained=False, **kwargs):
@register_model
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs an ECA-ResNeXt-26-TN model.
def ecaresnext26t_32x4d(pretrained=False, **kwargs):
"""Constructs an ECA-ResNeXt-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
this model replaces SE module with the ECA module
in the deep stem. This model replaces SE module with the ECA module
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args)
@register_model
def ecaresnext50t_32x4d(pretrained=False, **kwargs):
"""Constructs an ECA-ResNeXt-50-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. This model replaces SE module with the ECA module
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
@register_model
@@ -1203,11 +1199,11 @@ def seresnet50(pretrained=False, **kwargs):
@register_model
def seresnet50tn(pretrained=False, **kwargs):
def seresnet50t(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet50tn', pretrained, **model_args)
return _create_resnet('seresnet50t', pretrained, **model_args)
@register_model
@@ -1250,22 +1246,6 @@ def seresnet269d(pretrained=False, **kwargs):
return _create_resnet('seresnet269d', pretrained, **model_args)
@register_model
def seresnet152d_320(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet152d_320', pretrained, **model_args)
@register_model
def seresnext26_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26_32x4d', pretrained, **model_args)
@register_model
def seresnext26d_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNeXt-26-D model.`
@@ -1281,7 +1261,7 @@ def seresnext26d_32x4d(pretrained=False, **kwargs):
@register_model
def seresnext26t_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNet-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem.
"""
model_args = dict(
@@ -1292,14 +1272,11 @@ def seresnext26t_32x4d(pretrained=False, **kwargs):
@register_model
def seresnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNeXt-26-TN model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
"""Constructs a SE-ResNeXt-26-T model.
NOTE: I deprecated the previous 't' model defs and replaced 't' with 'tn'; this was the only 'tn' model of note,
so this def is kept for backwards compat with any uses out there. The old 't' model is lost.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
return seresnext26t_32x4d(pretrained=pretrained, **kwargs)
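Both names resolve to the same architecture after this commit. A minimal check (the 'tn' alias may disappear in later timm releases):

```python
import timm

# The deprecated 'tn' name is now a thin alias for the 't' def.
t = timm.create_model('seresnext26t_32x4d')
tn = timm.create_model('seresnext26tn_32x4d')
assert list(t.state_dict()) == list(tn.state_dict())
assert sum(p.numel() for p in t.parameters()) == sum(p.numel() for p in tn.parameters())
```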
@register_model
