From 9cc289f18c80100c8630808cf0842f4eb03f0b5d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 13:07:03 -0700 Subject: [PATCH] Exclude EfficientNet-L2 models from test --- tests/test_inference.py | 2 +- timm/models/registry.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index dc45c409..55bafb21 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,7 +5,7 @@ from timm import list_models, create_model @pytest.mark.timeout(300) -@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*')) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/registry.py b/timm/models/registry.py index c15f5414..2b8a3717 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -42,12 +42,14 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False): +def list_models(filter='', module='', pretrained=False, exclude_filters=''): """ Return list of available model names, sorted alphabetically Args: filter (str) - Wildcard filter string that works with fnmatch module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False): else: models = _model_entrypoints.keys() if filter: - models = fnmatch.filter(models, filter) + models = fnmatch.filter(models, filter) # include these models + if exclude_filters: + if not isinstance(exclude_filters, list): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) return list(sorted(models, key=_natural_key))