From 171c0b88b67e1d4b45af2a02fca0a2bee33baa5f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 23 Jun 2019 18:22:16 -0700 Subject: [PATCH] Add model registry and model listing fns, refactor model_factory/create_model fn --- timm/__init__.py | 2 +- timm/models/__init__.py | 16 ++- timm/models/densenet.py | 144 ++++++++++++++------------- timm/models/dpn.py | 152 +++++++++++++++-------------- timm/models/factory.py | 44 +++++++++ timm/models/gen_efficientnet.py | 46 +++++++-- timm/models/gluon_resnet.py | 42 ++++++-- timm/models/inception_resnet_v2.py | 12 ++- timm/models/inception_v3.py | 8 +- timm/models/inception_v4.py | 11 ++- timm/models/model_factory.py | 42 -------- timm/models/pnasnet.py | 5 +- timm/models/registry.py | 78 +++++++++++++++ timm/models/resnet.py | 24 ++++- timm/models/senet.py | 15 ++- timm/models/xception.py | 7 +- validate.py | 24 +++-- 17 files changed, 436 insertions(+), 236 deletions(-) create mode 100644 timm/models/factory.py delete mode 100644 timm/models/model_factory.py create mode 100644 timm/models/registry.py diff --git a/timm/__init__.py b/timm/__init__.py index 325a273a..86ed7a42 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,2 +1,2 @@ from .version import __version__ -from .models import create_model +from .models import create_model, list_models, is_model, list_modules, model_entrypoint diff --git a/timm/models/__init__.py b/timm/models/__init__.py index f451a270..06b1d178 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,4 +1,16 @@ -from .model_factory import create_model +from .inception_v4 import * +from .inception_resnet_v2 import * +from .densenet import * +from .resnet import * +from .dpn import * +from .senet import * +from .xception import * +from .pnasnet import * +from .gen_efficientnet import * +from .inception_v3 import * +from .gluon_resnet import * + +from .registry import * +from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint from 
.test_time_pool import TestTimePoolHead, apply_test_time_pool - diff --git a/timm/models/densenet.py b/timm/models/densenet.py index dd144b70..7c104654 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,13 +4,17 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. """ from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import * +from .adaptive_avgmax_pool import select_adaptive_pool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re -_models = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] -__all__ = ['DenseNet'] + _models +__all__ = ['DenseNet'] def _cfg(url=''): @@ -30,71 +34,6 @@ default_cfgs = { } -def _filter_pretrained(state_dict): - pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - return state_dict - - -def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - r"""Densenet-121 model from - `"Densely Connected Convolutional Networks" ` - """ - default_cfg = default_cfgs['densenet121'] - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) - return model - - -def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - r"""Densenet-169 model from - `"Densely Connected Convolutional Networks" ` - """ - default_cfg = default_cfgs['densenet169'] - model = DenseNet(num_init_features=64, growth_rate=32, 
block_config=(6, 12, 32, 32), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) - return model - - -def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - r"""Densenet-201 model from - `"Densely Connected Convolutional Networks" ` - """ - default_cfg = default_cfgs['densenet201'] - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) - return model - - -def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - r"""Densenet-201 model from - `"Densely Connected Convolutional Networks" ` - """ - default_cfg = default_cfgs['densenet161'] - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) - return model - - class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() @@ -205,3 +144,72 @@ class DenseNet(nn.Module): def forward(self, x): return self.classifier(self.forward_features(x, pool=True)) + +def _filter_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + + +@register_model +def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + 
 r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + default_cfg = default_cfgs['densenet121'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + return model + + +@register_model +def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" ` + """ + default_cfg = default_cfgs['densenet169'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + return model + + +@register_model +def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" ` + """ + default_cfg = default_cfgs['densenet201'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + return model + + +@register_model +def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" ` + """ + default_cfg = default_cfgs['densenet161'] + model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + return 
model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 636801e9..76b59ca2 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -14,12 +14,13 @@ import torch.nn as nn import torch.nn.functional as F from collections import OrderedDict +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import select_adaptive_pool2d from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD -_models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] -__all__ = ['DPN'] + _models + +__all__ = ['DPN'] def _cfg(url=''): @@ -47,78 +48,6 @@ default_cfgs = { } -def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn68'] - model = DPN( - small=True, num_init_features=10, k_r=128, groups=32, - k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn68b_extra'] - model = DPN( - small=True, num_init_features=10, k_r=128, groups=32, - b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn92_extra'] - model = DPN( - num_init_features=64, k_r=96, groups=32, - k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn98'] - model = DPN( - 
num_init_features=96, k_r=160, groups=40, - k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn131'] - model = DPN( - num_init_features=128, k_r=160, groups=40, - k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - -def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn107_extra'] - model = DPN( - num_init_features=128, k_r=200, groups=50, - k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - class CatBnAct(nn.Module): def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)): super(CatBnAct, self).__init__() @@ -317,3 +246,79 @@ class DPN(nn.Module): return out.view(out.size(0), -1) +@register_model +def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn68'] + model = DPN( + small=True, num_init_features=10, k_r=128, groups=32, + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn68b_extra'] + model = DPN( + small=True, num_init_features=10, k_r=128, groups=32, + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), + 
num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn92_extra'] + model = DPN( + num_init_features=64, k_r=96, groups=32, + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn98'] + model = DPN( + num_init_features=96, k_r=160, groups=40, + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn131'] + model = DPN( + num_init_features=128, k_r=160, groups=40, + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['dpn107_extra'] + model = DPN( + num_init_features=128, k_r=200, groups=50, + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 00000000..d807a342 --- /dev/null +++ 
b/timm/models/factory.py @@ -0,0 +1,44 @@ +from .registry import is_model, is_model_in_modules, model_entrypoint +from .helpers import load_checkpoint + + +def create_model( + model_name, + pretrained=False, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + """Create a model + + Args: + model_name (str): name of model to instantiate + pretrained (bool): load pretrained ImageNet-1k weights if true + num_classes (int): number of classes for final fully connected layer (default: 1000) + in_chans (int): number of input channels / colors (default: 3) + checkpoint_path (str): path of checkpoint to load after model is initialized + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are model specific + """ + margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) + + # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants + supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet']) + if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]): + kwargs.pop('bn_tf', None) + kwargs.pop('bn_momentum', None) + kwargs.pop('bn_eps', None) + + if is_model(model_name): + create_fn = model_entrypoint(model_name) + model = create_fn(**margs, **kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 1feee809..12d280fc 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -23,19 +23,15 @@ from copy import deepcopy import torch import torch.nn as nn import torch.nn.functional as F + +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from .conv2d_same import 
sconv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -_models = [ - 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075', - 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100', - 'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100', - 'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0', 'efficientnet_b1', - 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'tf_efficientnet_b0', - 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5'] -__all__ = ['GenEfficientNet', 'gen_efficientnet_model_names'] + _models + +__all__ = ['GenEfficientNet'] def _cfg(url='', **kwargs): @@ -1157,6 +1153,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes= return model +@register_model def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ default_cfg = default_cfgs['mnasnet_050'] @@ -1167,6 +1164,7 @@ def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 0.75. """ default_cfg = default_cfgs['mnasnet_075'] @@ -1177,6 +1175,7 @@ def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ default_cfg = default_cfgs['mnasnet_100'] @@ -1187,11 +1186,13 @@ def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 1.0. 
""" return mnasnet_100(num_classes, in_chans, pretrained, **kwargs) +@register_model def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ default_cfg = default_cfgs['tflite_mnasnet_100'] @@ -1205,6 +1206,7 @@ def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet B1, depth multiplier of 1.4 """ default_cfg = default_cfgs['mnasnet_140'] @@ -1215,6 +1217,7 @@ def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ default_cfg = default_cfgs['semnasnet_050'] @@ -1225,6 +1228,7 @@ def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ default_cfg = default_cfgs['semnasnet_075'] @@ -1235,6 +1239,7 @@ def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ default_cfg = default_cfgs['semnasnet_100'] @@ -1245,11 +1250,13 @@ def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ return semnasnet_100(num_classes, in_chans, pretrained, **kwargs) +@register_model def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1, depth multiplier of 1.0. 
""" default_cfg = default_cfgs['tflite_semnasnet_100'] @@ -1263,6 +1270,7 @@ def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwarg return model +@register_model def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ default_cfg = default_cfgs['semnasnet_140'] @@ -1273,6 +1281,7 @@ def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MNASNet Small, depth multiplier of 1.0. """ default_cfg = default_cfgs['mnasnet_small'] @@ -1283,6 +1292,7 @@ def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MobileNet V1 """ default_cfg = default_cfgs['mobilenetv1_100'] @@ -1293,6 +1303,7 @@ def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MobileNet V2 """ default_cfg = default_cfgs['mobilenetv2_100'] @@ -1303,6 +1314,7 @@ def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MobileNet V3 """ default_cfg = default_cfgs['mobilenetv3_050'] @@ -1313,6 +1325,7 @@ def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MobileNet V3 """ default_cfg = default_cfgs['mobilenetv3_075'] @@ -1323,6 +1336,7 @@ def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ MobileNet V3 """ 
default_cfg = default_cfgs['mobilenetv3_100'] @@ -1336,6 +1350,7 @@ def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ FBNet-C """ default_cfg = default_cfgs['fbnetc_100'] @@ -1349,6 +1364,7 @@ def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ ChamNet """ default_cfg = default_cfgs['chamnetv1_100'] @@ -1359,6 +1375,7 @@ def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ ChamNet """ default_cfg = default_cfgs['chamnetv2_100'] @@ -1369,6 +1386,7 @@ def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ Single-Path NAS Pixel1""" default_cfg = default_cfgs['spnasnet_100'] @@ -1379,6 +1397,7 @@ def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B0 """ default_cfg = default_cfgs['efficientnet_b0'] @@ -1392,6 +1411,7 @@ def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B1 """ default_cfg = default_cfgs['efficientnet_b1'] @@ -1405,6 +1425,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B2 """ default_cfg = default_cfgs['efficientnet_b2'] @@ -1418,6 +1439,7 @@ def 
efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B3 """ default_cfg = default_cfgs['efficientnet_b3'] @@ -1431,6 +1453,7 @@ def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B4 """ default_cfg = default_cfgs['efficientnet_b4'] @@ -1444,6 +1467,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B5 """ # NOTE for train, drop_rate should be 0.4 @@ -1457,6 +1481,7 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B0. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b0'] @@ -1471,6 +1496,7 @@ def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B1. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b1'] @@ -1485,6 +1511,7 @@ def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B2. 
Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b2'] @@ -1499,6 +1526,7 @@ def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B3. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b3'] @@ -1513,6 +1541,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B4. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b4'] @@ -1527,6 +1556,7 @@ def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B5. Tensorflow compatible variant """ default_cfg = default_cfgs['tf_efficientnet_b5'] diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 96e61aad..c5d0634f 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -3,21 +3,19 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) by Ross Wightman """ +import math + +import torch import torch.nn as nn import torch.nn.functional as F -import math + +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -_models = [ - 'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b', - 'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d', 'gluon_resnet101_v1d', - 'gluon_resnet152_v1d', 
'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e', 'gluon_resnet50_v1s', - 'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d', - 'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d', - 'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154'] -__all__ = ['GluonResNet'] + _models + +__all__ = ['GluonResNet'] def _cfg(url='', **kwargs): @@ -361,6 +359,7 @@ class GluonResNet(nn.Module): return x +@register_model def gluon_resnet18_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-18 model. """ @@ -372,6 +371,7 @@ def gluon_resnet18_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet34_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. """ @@ -383,6 +383,7 @@ def gluon_resnet34_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet50_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. """ @@ -394,6 +395,7 @@ def gluon_resnet50_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet101_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. """ @@ -405,6 +407,7 @@ def gluon_resnet101_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet152_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -416,6 +419,7 @@ def gluon_resnet152_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet50_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. 
""" @@ -428,6 +432,7 @@ def gluon_resnet50_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet101_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. """ @@ -440,6 +445,7 @@ def gluon_resnet101_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet152_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -452,6 +458,7 @@ def gluon_resnet152_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet50_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. """ @@ -464,6 +471,7 @@ def gluon_resnet50_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet101_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. """ @@ -476,6 +484,7 @@ def gluon_resnet101_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet152_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -488,6 +497,7 @@ def gluon_resnet152_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50-V1e model. No pretrained weights for any 'e' variants """ @@ -500,6 +510,7 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet101_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. 
""" @@ -512,6 +523,7 @@ def gluon_resnet101_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet152_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -524,6 +536,7 @@ def gluon_resnet152_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet50_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. """ @@ -536,6 +549,7 @@ def gluon_resnet50_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def gluon_resnet101_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. """ @@ -548,6 +562,7 @@ def gluon_resnet101_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnet152_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -560,6 +575,7 @@ def gluon_resnet152_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def gluon_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt50-32x4d model. """ @@ -573,6 +589,7 @@ def gluon_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwar return model +@register_model def gluon_resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 model. """ @@ -586,6 +603,7 @@ def gluon_resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model +@register_model def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 model. 
""" @@ -599,6 +617,7 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model +@register_model def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt152-32x4d model. """ @@ -612,6 +631,7 @@ def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model +@register_model def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt50-32x4d model. """ @@ -625,6 +645,7 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw return model +@register_model def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt-101-32x4d model. """ @@ -638,6 +659,7 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model +@register_model def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt-101-64x4d model. """ @@ -651,6 +673,7 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model +@register_model def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt152-32x4d model. """ @@ -664,6 +687,7 @@ def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model +@register_model def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. 
""" diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 29c68a8b..fe5679fe 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -2,12 +2,16 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import * +from .adaptive_avgmax_pool import select_adaptive_pool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2'] -__all__ = ['InceptionResnetV2'] + _models +__all__ = ['InceptionResnetV2'] default_cfgs = { # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz @@ -328,6 +332,7 @@ class InceptionResnetV2(nn.Module): return x +@register_model def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): r"""InceptionResnetV2 model architecture from the `"InceptionV4, Inception-ResNet..." ` paper. 
@@ -341,6 +346,7 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model +@register_model def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): r""" Ensemble Adversarially trained InceptionResnetV2 model architecture As per https://arxiv.org/abs/1705.07204 and diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index ba9ed493..b83895e7 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -1,9 +1,9 @@ from torchvision.models import Inception3 +from .registry import register_model from .helpers import load_pretrained from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3'] -__all__ = _models +__all__ = [] default_cfgs = { # original PyTorch weights, ported from Tensorflow but modified @@ -66,6 +66,7 @@ def _assert_default_kwargs(kwargs): assert kwargs.pop('drop_rate', 0.) == 0. 
+@register_model def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): # original PyTorch weights, ported from Tensorflow but modified default_cfg = default_cfgs['inception_v3'] @@ -78,6 +79,7 @@ def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) default_cfg = default_cfgs['tf_inception_v3'] @@ -90,6 +92,7 @@ def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz @@ -103,6 +106,7 @@ def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index ac819cfe..e389eb88 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -2,12 +2,16 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import * +from .adaptive_avgmax_pool import select_adaptive_pool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -_models = ['inception_v4'] -__all__ = ['InceptionV4'] + 
_models +__all__ = ['InceptionV4'] default_cfgs = { 'inception_v4': { @@ -293,6 +297,7 @@ class InceptionV4(nn.Module): return x +@register_model def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['inception_v4'] model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/timm/models/model_factory.py b/timm/models/model_factory.py deleted file mode 100644 index 9cce8bd1..00000000 --- a/timm/models/model_factory.py +++ /dev/null @@ -1,42 +0,0 @@ -from .inception_v4 import * -from .inception_resnet_v2 import * -from .densenet import * -from .resnet import * -from .dpn import * -from .senet import * -from .xception import * -from .pnasnet import * -from .gen_efficientnet import * -from .inception_v3 import * -from .gluon_resnet import * - -from .helpers import load_checkpoint - - -def create_model( - model_name, - pretrained=False, - num_classes=1000, - in_chans=3, - checkpoint_path='', - **kwargs): - - margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) - - # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants - supports_bn_params = model_name in gen_efficientnet_model_names() - if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]): - kwargs.pop('bn_tf', None) - kwargs.pop('bn_momentum', None) - kwargs.pop('bn_eps', None) - - if model_name in globals(): - create_fn = globals()[model_name] - model = create_fn(**margs, **kwargs) - else: - raise RuntimeError('Unknown model (%s)' % model_name) - - if checkpoint_path: - load_checkpoint(model, checkpoint_path) - - return model diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index c4b25820..e04a2b1f 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -12,11 +12,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from .registry import register_model from .helpers import load_pretrained 
from .adaptive_avgmax_pool import SelectAdaptivePool2d -_models = ['pnasnet5large'] -__all__ = ['PNASNet5Large'] + _models +__all__ = ['PNASNet5Large'] default_cfgs = { 'pnasnet5large': { @@ -385,6 +385,7 @@ class PNASNet5Large(nn.Module): return x +@register_model def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs): r"""PNASNet-5 model architecture from the `"Progressive Neural Architecture Search" diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 00000000..45bc1809 --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,78 @@ +import sys +import re +import fnmatch +from collections import defaultdict + +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] + +_module_to_models = defaultdict(set) +_model_to_module = {} +_model_entrypoints = {} + + +def register_model(fn): + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + if hasattr(mod, '__all__'): + mod.__all__.append(fn.__name__) + else: + mod.__all__ = [fn.__name__] + _model_entrypoints[fn.__name__] = fn + _model_to_module[fn.__name__] = module_name + _module_to_models[module_name].add(fn.__name__) + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models(filter='', module=''): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + + Example: + list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module + """ + if module: + models = list(_module_to_models[module]) + else: + models = _model_entrypoints.keys() + if filter: + models = 
fnmatch.filter(models, filter) + return list(sorted(models, key=_natural_key)) + + +def is_model(model_name): + """ Check if a model name exists + """ + return model_name in _model_entrypoints + + +def model_entrypoint(model_name): + """Fetch a model entrypoint for specified model name + """ + return _model_entrypoints[model_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + assert isinstance(module_names, (tuple, list, set)) + return any(model_name in _module_to_models[n] for n in module_names) + diff --git a/timm/models/resnet.py b/timm/models/resnet.py index cea922be..7ed3b2e1 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -4,17 +4,18 @@ additional dropout and dynamic global avg/max pool. ResNext additions added by Ross Wightman """ +import math + import torch.nn as nn import torch.nn.functional as F -import math + +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', - 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d', - 'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d'] -__all__ = ['ResNet'] + _models + +__all__ = ['ResNet'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): @@ -224,6 +225,7 @@ class ResNet(nn.Module): return x +@register_model def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-18 model. 
""" @@ -235,6 +237,7 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. """ @@ -246,6 +249,7 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. """ @@ -257,6 +261,7 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. """ @@ -268,6 +273,7 @@ def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-152 model. """ @@ -279,6 +285,7 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt50-32x4d model. """ @@ -292,6 +299,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 model. """ @@ -305,6 +313,7 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt101-64x4d model. """ @@ -318,6 +327,7 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt152-32x4d model. 
""" @@ -331,6 +341,7 @@ def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data and finetuned on ImageNet from Figure 5 in @@ -349,6 +360,7 @@ def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs) return model +@register_model def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data and finetuned on ImageNet from Figure 5 in @@ -367,6 +379,7 @@ def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs return model +@register_model def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data and finetuned on ImageNet from Figure 5 in @@ -385,6 +398,7 @@ def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs return model +@register_model def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data and finetuned on ImageNet from Figure 5 in diff --git a/timm/models/senet.py b/timm/models/senet.py index 22283116..0be9ac96 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -8,20 +8,18 @@ Original model: https://github.com/hujie-frank/SENet ResNet code gently borrowed from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py """ -from __future__ import print_function, division, absolute_import from collections import OrderedDict import math import torch.nn as nn import torch.nn.functional as F +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import 
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -_models = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'senet154', - 'seresnext26_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d'] -__all__ = ['SENet'] + _models +__all__ = ['SENet'] def _cfg(url='', **kwargs): @@ -400,6 +398,7 @@ class SENet(nn.Module): return x +@register_model def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet18'] model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, @@ -412,6 +411,7 @@ def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet34'] model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16, @@ -424,6 +424,7 @@ def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet50'] model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, @@ -436,6 +437,7 @@ def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet101'] model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, @@ -448,6 +450,7 @@ def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnet152'] model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, @@ -460,6 +463,7 @@ def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = 
default_cfgs['senet154'] model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, @@ -470,6 +474,7 @@ def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext26_32x4d'] model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16, @@ -482,6 +487,7 @@ def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext50_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, @@ -494,6 +500,7 @@ def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model def seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['seresnext101_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, diff --git a/timm/models/xception.py b/timm/models/xception.py index a2d63b6e..e76ed9ff 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -21,17 +21,17 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 """ -from __future__ import print_function, division, absolute_import import math + import torch import torch.nn as nn import torch.nn.functional as F +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import select_adaptive_pool2d -_models = ['xception'] -__all__ = ['Xception'] + _models +__all__ = ['Xception'] default_cfgs = { 'xception': { @@ -228,6 +228,7 @@ class Xception(nn.Module): return x +@register_model def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = 
default_cfgs['xception'] model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/validate.py b/validate.py index af7e343b..199888ad 100644 --- a/validate.py +++ b/validate.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.nn.parallel from collections import OrderedDict -from timm.models import create_model, apply_test_time_pool, load_checkpoint +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import Dataset, create_loader, resolve_data_config from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging @@ -144,22 +144,26 @@ def validate(args): def main(): setup_default_logging() args = parser.parse_args() - if args.model == 'all': - # validate all models in a list of names with pretrained checkpoints - args.pretrained = True - # FIXME just an example list, need to add model name collections for - # batch testing of various pretrained combinations by arg string - models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3'] - model_cfgs = [(n, '') for n in models] - elif os.path.isdir(args.checkpoint): + model_cfgs = [] + model_names = [] + if os.path.isdir(args.checkpoint): # validate all checkpoints in a path with same model checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') checkpoints += glob.glob(args.checkpoint + '/*.pth') model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] else: - model_cfgs = [] + if args.model == 'all': + # validate all models in a list of names with pretrained checkpoints + args.pretrained = True + model_names = list_models() + model_cfgs = [(n, '') for n in model_names] + elif not is_model(args.model): + # model name doesn't exist, try as wildcard filter + model_names = list_models(args.model) + model_cfgs = [(n, '') for n in model_names] if len(model_cfgs): + print('Running bulk validation on these pretrained models:', ', '.join(model_names)) 
header_written = False with open('./results-all.csv', mode='w') as cf: for m, c in model_cfgs: