Exclude EfficientNet-L2 models from test

pull/146/head
Ross Wightman 5 years ago
parent e545bb9401
commit 9cc289f18c

@ -5,7 +5,7 @@ from timm import list_models, create_model
@pytest.mark.timeout(300) @pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models()) @pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size): def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""

@ -42,12 +42,14 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False): def list_models(filter='', module='', pretrained=False, exclude_filters=''):
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
filter (str) - Wildcard filter string that works with fnmatch filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
else: else:
models = _model_entrypoints.keys() models = _model_entrypoints.keys()
if filter: if filter:
models = fnmatch.filter(models, filter) models = fnmatch.filter(models, filter) # include these models
if exclude_filters:
if not isinstance(exclude_filters, list):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if pretrained: if pretrained:
models = _model_has_pretrained.intersection(models) models = _model_has_pretrained.intersection(models)
return list(sorted(models, key=_natural_key)) return list(sorted(models, key=_natural_key))

Loading…
Cancel
Save