Model updates. Add my best ResNet50 weights top-1=78.47. Add some other torchvision weights.

* Remove some models that don't exist as pretrained an likely never will (se)resnext152
* Add some torchvision weights as tv_ for models that I have added better weights for
* Add wide resnet recently added to torchvision along with resnext101-32x8d
* Add functionality to model registry to allow filtering on pretrained weight presence
pull/16/head
Ross Wightman 6 years ago
parent 65a634626f
commit b8762cc67d

@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
|---|---|---|---|---| |---|---|---|---|---|
| resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic | | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic |
| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic |
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic | | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |

@ -35,9 +35,9 @@ def _cfg(url=''):
default_cfgs = { default_cfgs = {
'dpn68': _cfg( 'dpn68': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68b_extra': _cfg( 'dpn68b': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'), url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
'dpn92_extra': _cfg( 'dpn92': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
'dpn98': _cfg( 'dpn98': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),

@ -50,11 +50,9 @@ default_cfgs = {
'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
'gluon_resnext152_32x4d': _cfg(url=''),
'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
'gluon_seresnext152_32x4d': _cfg(url=''),
'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'), 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'),
} }
@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
return model return model
@register_model
def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt152-32x4d model.
"""
default_cfg = default_cfgs['gluon_resnext152_32x4d']
model = GluonResNet(
BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SEResNeXt50-32x4d model. """Constructs a SEResNeXt50-32x4d model.
@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
return model return model
@register_model
def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SEResNeXt152-32x4d model.
"""
default_cfg = default_cfgs['gluon_seresnext152_32x4d']
model = GluonResNet(
BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
#if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an SENet-154 model. """Constructs an SENet-154 model.

@ -5,22 +5,36 @@ from collections import defaultdict
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
_module_to_models = defaultdict(set) _module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} _model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} _model_entrypoints = {} # mapping of model names to entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
def register_model(fn): def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__] mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.') module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else '' module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'): if hasattr(mod, '__all__'):
mod.__all__.append(fn.__name__) mod.__all__.append(model_name)
else: else:
mod.__all__ = [fn.__name__] mod.__all__ = [model_name]
_model_entrypoints[fn.__name__] = fn
_model_to_module[fn.__name__] = module_name # add entries to registry dict/sets
_module_to_models[module_name].add(fn.__name__) _model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn return fn
@ -28,7 +42,7 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module=''): def list_models(filter='', module='', pretrained=False):
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
@ -45,6 +59,8 @@ def list_models(filter='', module=''):
models = _model_entrypoints.keys() models = _model_entrypoints.keys()
if filter: if filter:
models = fnmatch.filter(models, filter) models = fnmatch.filter(models, filter)
if pretrained:
models = _model_has_pretrained.intersection(models)
return list(sorted(models, key=_natural_key)) return list(sorted(models, key=_natural_key))

@ -33,14 +33,22 @@ default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg( 'resnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), 'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
interpolation='bicubic'),
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1', 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
interpolation='bicubic'), 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'),
'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
'resnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
interpolation='bicubic'),
'resnext101_32x4d': _cfg(url=''), 'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(url=''), 'resnext101_64x4d': _cfg(url=''),
'resnext152_32x4d': _cfg(url=''), 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model return model
@register_model
def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-34 model with original Torchvision weights.
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet34']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model with original Torchvision weights.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet50']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Wide ResNet-50-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
"""
model = ResNet(
Bottleneck, [3, 4, 6, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet50_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Wide ResNet-100-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same.
"""
model = ResNet(
Bottleneck, [3, 4, 23, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet101_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt50-32x4d model. """Constructs a ResNeXt50-32x4d model.
@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model @register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 model. """Constructs a ResNeXt-101 32x4d model.
""" """
default_cfg = default_cfgs['resnext101_32x4d'] default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet( model = ResNet(
@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model return model
@register_model
def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
"""
default_cfg = default_cfgs['resnext101_32x8d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt101-64x4d model. """Constructs a ResNeXt101-64x4d model.
@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model @register_model
def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt152-32x4d model. """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
""" """
default_cfg = default_cfgs['resnext152_32x4d'] default_cfg = default_cfgs['tv_resnext50_32x4d']
model = ResNet( model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4, Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:

Loading…
Cancel
Save