From b8762cc67d79464ae3105fecf6e6f35ebe8ae230 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 2019 15:37:42 -0700 Subject: [PATCH] Model updates. Add my best ResNet50 weights top-1=78.47. Add some other torchvision weights. * Remove some models that don't exist as pretrained an likely never will (se)resnext152 * Add some torchvision weights as tv_ for models that I have added better weights for * Add wide resnet recently added to torchvision along with resnext101-32x8d * Add functionality to model registry to allow filtering on pretrained weight presence --- README.md | 1 + timm/models/dpn.py | 4 +- timm/models/gluon_resnet.py | 30 ------------ timm/models/registry.py | 34 +++++++++---- timm/models/resnet.py | 95 +++++++++++++++++++++++++++++++++---- 5 files changed, 114 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 64d8a6f1..ac2fbdf6 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic | +| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic | | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic | | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 76b59ca2..92bc7855 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -35,9 +35,9 @@ def _cfg(url=''): default_cfgs = { 'dpn68': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), - 'dpn68b_extra': _cfg( + 'dpn68b': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'), - 'dpn92_extra': _cfg( + 'dpn92': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), 'dpn98': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index c5d0634f..715e0950 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -50,11 +50,9 @@ default_cfgs = { 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), - 'gluon_resnext152_32x4d': _cfg(url=''), 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), - 'gluon_seresnext152_32x4d': _cfg(url=''), 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'), } @@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model -@register_model -def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. - """ - default_cfg = default_cfgs['gluon_resnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt50-32x4d model. @@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model -@register_model -def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a SEResNeXt152-32x4d model. - """ - default_cfg = default_cfgs['gluon_seresnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - #if pretrained: - # load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. diff --git a/timm/models/registry.py b/timm/models/registry.py index 45bc1809..c15f5414 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -5,22 +5,36 @@ from collections import defaultdict __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] -_module_to_models = defaultdict(set) -_model_to_module = {} -_model_entrypoints = {} +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present def register_model(fn): + # lookup containing module mod = sys.modules[fn.__module__] module_name_split = fn.__module__.split('.') module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ if hasattr(mod, '__all__'): - mod.__all__.append(fn.__name__) + mod.__all__.append(model_name) else: - mod.__all__ = [fn.__name__] - _model_entrypoints[fn.__name__] = fn - _model_to_module[fn.__name__] = module_name - _module_to_models[module_name].add(fn.__name__) + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + if has_pretrained: + _model_has_pretrained.add(model_name) return fn @@ -28,7 +42,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module=''): +def list_models(filter='', module='', pretrained=False): """ Return list of available model names, sorted alphabetically Args: @@ -45,6 +59,8 @@ def list_models(filter='', module=''): models = _model_entrypoints.keys() if filter: models = fnmatch.filter(models, filter) + if pretrained: + models = _model_has_pretrained.intersection(models) return list(sorted(models, key=_natural_key)) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7ed3b2e1..9a4b22cd 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -33,14 +33,22 @@ default_cfgs = { 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), - 'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'resnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', + interpolation='bicubic'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), - 'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1', - interpolation='bicubic'), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + 'resnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth', + interpolation='bicubic'), 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), - 'resnext152_32x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), @@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet34'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet50'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model = ResNet( + Bottleneck, [3, 4, 6, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet50_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-100-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. + """ + model = ResNet( + Bottleneck, [3, 4, 23, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet101_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + @register_model def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt50-32x4d model. @@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt-101 model. + """Constructs a ResNeXt-101 32x4d model. """ default_cfg = default_cfgs['resnext101_32x4d'] model = ResNet( @@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + default_cfg = default_cfgs['resnext101_32x8d'] + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt101-64x4d model. @@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. +def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt50-32x4d model with original Torchvision weights. """ - default_cfg = default_cfgs['resnext152_32x4d'] + default_cfg = default_cfgs['tv_resnext50_32x4d'] model = ResNet( - Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4, + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: