Filter test models before creation for backward/torchscript tests

pull/637/head
Ross Wightman 4 years ago
parent c4572cc5aa
commit d400f1dbdd

@@ -40,14 +40,25 @@ TARGET_FFEAT_SIZE = 96
 MAX_FFEAT_SIZE = 256
 
-def _get_input_size(model, target=None):
-    default_cfg = model.default_cfg
-    input_size = default_cfg['input_size']
-    if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']:
+def _get_input_size(model=None, model_name='', target=None):
+    if model is None:
+        assert model_name, "One of model or model_name must be provided"
+        input_size = get_model_default_value(model_name, 'input_size')
+        fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
+        min_input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        default_cfg = model.default_cfg
+        input_size = default_cfg['input_size']
+        fixed_input_size = default_cfg.get('fixed_input_size', None)
+        min_input_size = default_cfg.get('min_input_size', None)
+    assert input_size is not None
+
+    if fixed_input_size:
         return input_size
-    if 'min_input_size' in default_cfg:
+
+    if min_input_size:
         if target and max(input_size) > target:
-            input_size = default_cfg['min_input_size']
+            input_size = min_input_size
     else:
         if target and max(input_size) > target:
             input_size = tuple([min(x, target) for x in input_size])
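Note: with this refactor the helper can resolve an input size from registry metadata alone, so callers no longer need an instantiated model. A minimal sketch of both call paths, assuming the test module's helpers are in scope ('resnet50' and the 128 target are arbitrary example values, not from this diff):

    from timm import create_model

    # By name: reads the registered default_cfg, no model construction needed.
    size = _get_input_size(model_name='resnet50', target=128)

    # By instance: falls back to the model's own default_cfg.
    model = create_model('resnet50', pretrained=False)
    size = _get_input_size(model=model, target=128)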
@@ -73,18 +84,18 @@ def test_model_forward(model_name, batch_size):
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
     """Run a single forward pass with each model"""
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
+    if max(input_size) > MAX_BWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
     model = create_model(model_name, pretrained=False, num_classes=42)
     num_params = sum([x.numel() for x in model.parameters()])
     model.train()
-    input_size = _get_input_size(model, TARGET_BWD_SIZE)
-    if max(input_size) > MAX_BWD_SIZE:
-        pytest.skip("Fixed input size model > limit.")
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     if isinstance(outputs, tuple):
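The reordering above is the point of the commit: the size check now runs before create_model, so oversized fixed-input models are skipped without ever being built. The hunk also cuts off before the backward step; roughly, the test continues along these lines (a sketch inferred from context, not the verbatim file):

    if isinstance(outputs, tuple):
        outputs = torch.cat(outputs)
    outputs.mean().backward()  # reduce to a scalar and back-propagate
    for n, p in model.named_parameters():
        assert p.grad is not None, f'No gradient for {n}'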
@@ -172,18 +183,19 @@ EXCLUDE_JIT_FILTERS = [
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS))
+@pytest.mark.parametrize(
+    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_torchscript(model_name, batch_size):
     """Run a single forward pass with each model"""
+    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
+    if max(input_size) > MAX_JIT_SIZE:  # NOTE using MAX_FWD_SIZE as the final limit is intentional
+        pytest.skip("Fixed input size model > limit.")
+
     with set_scriptable(True):
         model = create_model(model_name, pretrained=False)
     model.eval()
-    input_size = _get_input_size(model, TARGET_JIT_SIZE)
-    if max(input_size) > MAX_JIT_SIZE:  # NOTE using MAX_FWD_SIZE as the final limit is intentional
-        pytest.skip("Fixed input size model > limit.")
     model = torch.jit.script(model)
     outputs = model(torch.randn((batch_size, *input_size)))
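Same pattern for the torchscript test: filtering on registry metadata avoids constructing and scripting models that would be skipped anyway. For reference, the scriptable-creation idiom used here works standalone roughly as follows ('resnet18' and the 224x224 input are arbitrary example values):

    import torch
    from timm import create_model
    from timm.models.layers import set_scriptable

    with set_scriptable(True):  # build with jit-compatible layer implementations
        model = create_model('resnet18', pretrained=False)
    model.eval()
    scripted = torch.jit.script(model)
    out = scripted(torch.randn(1, 3, 224, 224))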

@@ -50,7 +50,7 @@ def _natural_key(string_):
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
 
-def list_models(filter='', module='', pretrained=False, exclude_filters=''):
+def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
     """ Return list of available model names, sorted alphabetically
 
     Args:
@@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
         module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
         pretrained (bool) - Include only models with pretrained weights if True
         exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
+        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
 
     Example:
         model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
     if filter:
         models = fnmatch.filter(models, filter)  # include these models
     if exclude_filters:
-        if not isinstance(exclude_filters, list):
+        if not isinstance(exclude_filters, (tuple, list)):
             exclude_filters = [exclude_filters]
         for xf in exclude_filters:
             exclude_models = fnmatch.filter(models, xf)  # exclude these models
@@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
                 models = set(models).difference(exclude_models)
     if pretrained:
         models = _model_has_pretrained.intersection(models)
+    if name_matches_cfg:
+        models = set(_model_default_cfgs).intersection(models)
     return list(sorted(models, key=_natural_key))
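Usage note: the new flag drops alias names that have no default_cfg entry of their own, which is what lets the tests above resolve input sizes purely by name. A small example (the exclude pattern is illustrative, not taken from the tests' EXCLUDE_FILTERS):

    from timm import list_models

    # Every returned name has its own default_cfg entry, so registry lookups
    # such as get_model_default_value(name, 'input_size') are well-defined.
    names = list_models(exclude_filters=['*enormous*'], name_matches_cfg=True)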
