Exclude EfficientNet-L2 models from test

pull/146/head
Ross Wightman 5 years ago
parent e545bb9401
commit 9cc289f18c

@ -5,7 +5,7 @@ from timm import list_models, create_model
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models())
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""

@ -42,12 +42,14 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False):
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
""" Return list of available model names, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
else:
models = _model_entrypoints.keys()
if filter:
models = fnmatch.filter(models, filter)
models = fnmatch.filter(models, filter) # include these models
if exclude_filters:
if not isinstance(exclude_filters, list):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
return list(sorted(models, key=_natural_key))

Loading…
Cancel
Save