Add default_cfg options for min_input_size / fixed_input_size, queries in model registry, and use for testing self-attn models

pull/556/head
Ross Wightman 3 years ago
parent 4e4b863b15
commit 16f7aa9f54

@ -5,7 +5,8 @@ import os
import fnmatch
import timm
from timm import list_models, create_model, set_scriptable
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
model.eval()
input_size = model.default_cfg['input_size']
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
if not is_model_default_key(model_name, 'fixed_input_size'):
min_input_size = get_model_default_value(model_name, 'min_input_size')
if min_input_size is not None:
input_size = min_input_size
else:
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs.mean().backward()
@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))
@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
model.eval()
expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
for e, o in zip(expected_channels, outputs):

@ -1,3 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
get_model_default_value, is_model_pretrained

@ -40,4 +40,5 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained

@ -37,23 +37,24 @@ def _cfg(url='', **kwargs):
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
**kwargs
}
default_cfgs = {
# GPU-Efficient (ResNet) weights
'botnet50t_224': _cfg(url=''),
'botnet50t_c4c5_224': _cfg(url=''),
'botnet50t_224': _cfg(url='', fixed_input_size=True),
'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(url=''),
'halonet50t': _cfg(url=''),
'lambda_resnet26t': _cfg(url=''),
'lambda_resnet50t': _cfg(url=''),
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
}

@ -6,13 +6,16 @@ import sys
import re
import fnmatch
from collections import defaultdict
from copy import deepcopy
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model default_cfgs
def register_model(fn):
@ -37,6 +40,7 @@ def register_model(fn):
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names):
assert isinstance(module_names, (tuple, list, set))
return any(model_name in _module_to_models[n] for n in module_names)
def has_model_default_key(model_name, cfg_key):
""" Query model default_cfgs for existence of a specific key.
"""
if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
return True
return False
def is_model_default_key(model_name, cfg_key):
""" Return truthy value for specified model default_cfg key, False if does not exist.
"""
if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
return True
return False
def get_model_default_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if it doesn't exist.
"""
if model_name in _model_default_cfgs:
return _model_default_cfgs[model_name].get(cfg_key, None)
else:
return None
def is_model_pretrained(model_name):
return model_name in _model_has_pretrained

Loading…
Cancel
Save