You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/registry.py

79 lines
2.3 KiB

import sys
import re
import fnmatch
from collections import defaultdict
__all__ = [
    'list_models',
    'is_model',
    'model_entrypoint',
    'list_modules',
    'is_model_in_modules',
]

# Global registry state, filled in by the ``register_model`` decorator.
# short module name (last component of the dotted path) -> set of model names
_module_to_models = defaultdict(set)
# model name -> short module name it was registered from
_model_to_module = dict()
# model name -> entrypoint callable
_model_entrypoints = dict()
def register_model(fn):
    """Decorator that registers a model entrypoint function in the global registry.

    The function is recorded under its ``__name__`` and associated with the last
    component of its defining module's dotted name (e.g. ``timm.models.resnet`` ->
    ``resnet``). The name is also exported via the defining module's ``__all__``,
    which is created if missing.

    Args:
        fn (callable): model entrypoint; its ``__module__`` must be present in
            ``sys.modules`` (true for functions decorated while their module imports).

    Returns:
        callable: ``fn`` unchanged, so this works as a plain decorator.
    """
    model_name = fn.__name__
    mod = sys.modules[fn.__module__]
    # str.split always yields at least one element, so [-1] is always valid.
    module_name = fn.__module__.split('.')[-1]

    # Export the entrypoint from its module. Names already present are skipped so
    # that re-running the decorators (e.g. importlib.reload) adds no duplicates.
    exported = getattr(mod, '__all__', None)
    if exported is None:
        mod.__all__ = [model_name]
    elif model_name not in exported:
        if isinstance(exported, list):
            exported.append(model_name)
        else:
            # __all__ may legally be a tuple (or other sequence); rebuild as a list.
            mod.__all__ = list(exported) + [model_name]

    # If this name was previously registered from a different module, remove it
    # from that module's index so the lookup tables stay consistent.
    previous_module = _model_to_module.get(model_name)
    if previous_module is not None and previous_module != module_name:
        previous_models = _module_to_models.get(previous_module)
        if previous_models is not None:
            previous_models.discard(model_name)
            if not previous_models:
                del _module_to_models[previous_module]

    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    return fn
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module=''):
    """ Return list of available model names, sorted in natural order

    Sorting is case-insensitive and number-aware, so embedded numbers compare
    numerically (``resnet50`` before ``resnet101``).

    Args:
        filter (str or list/tuple of str) - Wildcard filter string(s) that work with fnmatch.
            When several patterns are given, names matching any of them are returned.
        module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')

    Returns:
        list of str: matching model names

    Example:
        list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if module:
        # Use .get rather than indexing: _module_to_models is a defaultdict, and
        # indexing it with an unknown module would insert an empty entry as a side
        # effect, which would then be reported by list_modules().
        models = _module_to_models.get(module, ())
    else:
        models = _model_entrypoints.keys()
    if filter:
        patterns = [filter] if isinstance(filter, str) else filter
        matched = set()
        for pattern in patterns:
            matched.update(fnmatch.filter(models, pattern))
        models = matched
    return sorted(models, key=_natural_key)
def is_model(model_name):
    """Report whether ``model_name`` is a registered model entrypoint.

    Returns:
        bool: True if the name is a key of the entrypoint registry.
    """
    registered = _model_entrypoints
    found = model_name in registered
    return found
def model_entrypoint(model_name):
    """Look up the callable registered under ``model_name``.

    Raises:
        KeyError: if no entrypoint is registered under that name.
    """
    entrypoint_fn = _model_entrypoints[model_name]
    return entrypoint_fn
def list_modules():
    """ Return sorted list of module names that contain models / model entrypoints

    Only modules with at least one registered model are listed. The module index
    is a defaultdict, and indexing a defaultdict with an unknown key inserts an
    empty entry, so entries whose model set is empty are skipped.
    """
    return sorted(name for name, models in _module_to_models.items() if models)
def is_model_in_modules(model_name, module_names):
    """Check if a model exists within a subset of modules

    Args:
        model_name (str) - name of model to check
        module_names (tuple, list, set, frozenset) - names of modules to search in

    Returns:
        bool: True if ``model_name`` was registered from any of ``module_names``.
    """
    # Guard against a bare string, which would otherwise be iterated char by char.
    assert isinstance(module_names, (tuple, list, set, frozenset)), \
        'module_names must be a tuple, list, set or frozenset'
    # .get avoids inserting empty entries into the defaultdict for unknown modules.
    return any(model_name in _module_to_models.get(n, ()) for n in module_names)