diff --git a/tests/test_models.py b/tests/test_models.py index e6a73619..63a95fa5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,14 +40,25 @@ TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 -def _get_input_size(model, target=None): - default_cfg = model.default_cfg - input_size = default_cfg['input_size'] - if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: +def _get_input_size(model=None, model_name='', target=None): + if model is None: + assert model_name, "One of model or model_name must be provided" + input_size = get_model_default_value(model_name, 'input_size') + fixed_input_size = get_model_default_value(model_name, 'fixed_input_size') + min_input_size = get_model_default_value(model_name, 'min_input_size') + else: + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + fixed_input_size = default_cfg.get('fixed_input_size', None) + min_input_size = default_cfg.get('min_input_size', None) + assert input_size is not None + + if fixed_input_size: return input_size - if 'min_input_size' in default_cfg: + + if min_input_size: if target and max(input_size) > target: - input_size = default_cfg['min_input_size'] + input_size = min_input_size else: if target and max(input_size) > target: input_size = tuple([min(x, target) for x in input_size]) @@ -73,18 +84,18 @@ def test_model_forward(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) model.train() - input_size = _get_input_size(model, TARGET_BWD_SIZE) - if max(input_size) > MAX_BWD_SIZE: - pytest.skip("Fixed input size model > limit.") - inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) if isinstance(outputs, tuple): @@ -172,18 +183,19 @@ EXCLUDE_JIT_FILTERS = [ @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") + with set_scriptable(True): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional - pytest.skip("Fixed input size model > limit.") - model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) diff --git a/timm/models/registry.py b/timm/models/registry.py index 9172ac7e..6927b6d6 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -50,7 +50,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False, exclude_filters=''): +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): """ Return list of available model names, sorted alphabetically Args: @@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') pretrained (bool) - Include only models with pretrained weights if True exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): if filter: models = fnmatch.filter(models, filter) # include these models if exclude_filters: - if not isinstance(exclude_filters, list): + if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] for xf in exclude_filters: exclude_models = fnmatch.filter(models, xf) # exclude these models @@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) return list(sorted(models, key=_natural_key))