diff --git a/tests/test_models.py b/tests/test_models.py index 4fbdc85b..1eece766 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,8 @@ import os import fnmatch import timm -from timm import list_models, create_model, set_scriptable +from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ + get_model_default_value if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size): model.eval() input_size = model.default_cfg['input_size'] - if any([x > MAX_BWD_SIZE for x in input_size]): - # cap backward test at 128 * 128 to keep resource usage down - input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + if not is_model_default_key(model_name, 'fixed_input_size'): + min_input_size = get_model_default_value(model_name, 'min_input_size') + if min_input_size is not None: + input_size = min_input_size + else: + if any([x > MAX_BWD_SIZE for x in input_size]): + # cap backward test at 128 * 128 to keep resource usage down + input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) outputs.mean().backward() @@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size): with set_scriptable(True): model = create_model(model_name, pretrained=False) model.eval() - input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + + if has_model_default_key(model_name, 'fixed_input_size'): + input_size = get_model_default_value(model_name, 'input_size') + elif has_model_default_key(model_name, 'min_input_size'): + input_size = get_model_default_value(model_name, 'min_input_size') + else: + input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) @@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size): model.eval() expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + + if has_model_default_key(model_name, 'fixed_input_size'): + input_size = get_model_default_value(model_name, 'input_size') + elif has_model_default_key(model_name, 'min_input_size'): + input_size = get_model_default_value(model_name, 'min_input_size') + else: + input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) for e, o in zip(expected_channels, outputs): diff --git a/timm/__init__.py b/timm/__init__.py index db3d3f22..04ec7e51 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,3 +1,4 @@ from .version import __version__ from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable + is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \ + get_model_default_value, is_model_pretrained diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 3ed8bdb3..0ed02652 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -40,4 +40,5 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules +from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ + has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 935b6309..faf366e6 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -37,23 +37,24 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + 'fixed_input_size': False, 'min_input_size': (3, 224, 224), **kwargs } default_cfgs = { # GPU-Efficient (ResNet) weights - 'botnet50t_224': _cfg(url=''), - 'botnet50t_c4c5_224': _cfg(url=''), + 'botnet50t_224': _cfg(url='', fixed_input_size=True), + 'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True), - 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg(url=''), 'halonet50t': _cfg(url=''), - 'lambda_resnet26t': _cfg(url=''), - 'lambda_resnet50t': _cfg(url=''), + 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)), + 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), } diff --git a/timm/models/registry.py b/timm/models/registry.py index 3317eece..9172ac7e 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -6,13 +6,16 @@ import sys import re import fnmatch from collections import defaultdict +from copy import deepcopy -__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module _model_to_module = {} # mapping of model names to module names _model_entrypoints = {} # mapping of model names to entrypoint fns _model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() # central repo for model default_cfgs def register_model(fn): @@ -37,6 +40,7 @@ def register_model(fn): # this will catch all models that have entrypoint matching cfg key, but miss any aliasing # entrypoints or non-matching combos has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) if has_pretrained: _model_has_pretrained.add(model_name) return fn @@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names): assert isinstance(module_names, (tuple, list, set)) return any(model_name in _module_to_models[n] for n in module_names) + +def has_model_default_key(model_name, cfg_key): + """ Query model default_cfgs for existence of a specific key. + """ + if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: + return True + return False + + +def is_model_default_key(model_name, cfg_key): + """ Return truthy value for specified model default_cfg key, False if does not exist. + """ + if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): + return True + return False + + +def get_model_default_value(model_name, cfg_key): + """ Get a specific model default_cfg value by key. None if it doesn't exist. + """ + if model_name in _model_default_cfgs: + return _model_default_cfgs[model_name].get(cfg_key, None) + else: + return None + + +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained