diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py
index 11e4cff5..6345ce4e 100644
--- a/timm/models/_pretrained.py
+++ b/timm/models/_pretrained.py
@@ -93,7 +93,7 @@ class DefaultCfg:
         return tag, self.cfgs[tag]
 
 
-def split_model_name_tag(model_name: str, no_tag=''):
+def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
     model_name, *tag_list = model_name.split('.', 1)
     tag = tag_list[0] if tag_list else no_tag
     return model_name, tag
diff --git a/timm/models/_registry.py b/timm/models/_registry.py
index 80eb2e94..76361ec5 100644
--- a/timm/models/_registry.py
+++ b/timm/models/_registry.py
@@ -8,7 +8,7 @@ import sys
 from collections import defaultdict, deque
 from copy import deepcopy
 from dataclasses import replace
-from typing import List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
 
 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
 
@@ -16,20 +16,20 @@ __all__ = [
     'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
     'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
 
-_module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
-_model_to_module = {}  # mapping of model names to module names
-_model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
-_model_has_pretrained = set()  # set of model names that have pretrained weight url present
-_model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
-_model_pretrained_cfgs = dict()  # central repo for model arch.tag -> pretrained cfgs
-_model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names
+_module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
+_model_to_module: Dict[str, str] = {}  # mapping of model names to module names
+_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns
+_model_has_pretrained: Set[str] = set()  # set of model names that have pretrained weight url present
+_model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
+_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
+_model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names
 
 
-def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
+def get_arch_name(model_name: str) -> str:
     return split_model_name_tag(model_name)[0]
 
 
-def register_model(fn):
+def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
     # lookup containing module
     mod = sys.modules[fn.__module__]
     module_name_split = fn.__module__.split('.')
@@ -40,7 +40,7 @@ def register_model(fn):
     if hasattr(mod, '__all__'):
         mod.__all__.append(model_name)
     else:
-        mod.__all__ = [model_name]
+        mod.__all__ = [model_name]  # type: ignore
 
     # add entries to registry dict/sets
     _model_entrypoints[model_name] = fn
@@ -87,28 +87,33 @@ def register_model(fn):
     return fn
 
 
-def _natural_key(string_):
+def _natural_key(string_: str) -> List[Union[int, str]]:
+    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
 
 
 def list_models(
         filter: Union[str, List[str]] = '',
         module: str = '',
-        pretrained=False,
-        exclude_filters: str = '',
+        pretrained: bool = False,
+        exclude_filters: Union[str, List[str]] = '',
         name_matches_cfg: bool = False,
         include_tags: Optional[bool] = None,
-):
+) -> List[str]:
     """ Return list of available model names, sorted alphabetically
 
     Args:
-        filter (str) - Wildcard filter string that works with fnmatch
-        module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
-        pretrained (bool) - Include only models with valid pretrained weights if True
-        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
-        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
-        include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
+        filter - Wildcard filter string that works with fnmatch
+        module - Limit model selection to a specific submodule (ie 'vision_transformer')
+        pretrained - Include only models with valid pretrained weights if True
+        exclude_filters - Wildcard filters to exclude models after including them with filter
+        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
+        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
             set to True when pretrained=True else False (default: None)
+
+    Returns:
+        models - The sorted list of models
+
     Example:
         model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
         model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -118,7 +123,7 @@ def list_models(
         include_tags = pretrained
 
     if module:
-        all_models = list(_module_to_models[module])
+        all_models: Iterable[str] = list(_module_to_models[module])
     else:
         all_models = _model_entrypoints.keys()
 
@@ -130,14 +135,14 @@ def list_models(
         all_models = models_with_tags
 
     if filter:
-        models = []
+        models: Set[str] = set()
         include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
         for f in include_filters:
             include_models = fnmatch.filter(all_models, f)  # include these models
             if len(include_models):
-                models = set(models).union(include_models)
+                models = models.union(include_models)
     else:
-        models = all_models
+        models = set(all_models)
 
     if exclude_filters:
         if not isinstance(exclude_filters, (tuple, list)):
@@ -145,7 +150,7 @@ def list_models(
         for xf in exclude_filters:
             exclude_models = fnmatch.filter(models, xf)  # exclude these models
             if len(exclude_models):
-                models = set(models).difference(exclude_models)
+                models = models.difference(exclude_models)
 
     if pretrained:
         models = _model_has_pretrained.intersection(models)
@@ -153,13 +158,13 @@ def list_models(
     if name_matches_cfg:
         models = set(_model_pretrained_cfgs).intersection(models)
 
-    return list(sorted(models, key=_natural_key))
+    return sorted(models, key=_natural_key)
 
 
 def list_pretrained(
         filter: Union[str, List[str]] = '',
         exclude_filters: str = '',
-):
+) -> List[str]:
     return list_models(
         filter=filter,
         pretrained=True,
@@ -168,14 +173,14 @@ def list_pretrained(
     )
 
 
-def is_model(model_name):
+def is_model(model_name: str) -> bool:
     """ Check if a model name exists
     """
     arch_name = get_arch_name(model_name)
     return arch_name in _model_entrypoints
 
 
-def model_entrypoint(model_name, module_filter: Optional[str] = None):
+def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch a model entrypoint for specified model name
    """
    arch_name = get_arch_name(model_name)
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
     return _model_entrypoints[arch_name]
 
 
-def list_modules():
+def list_modules() -> List[str]:
     """ Return list of module names that contain models / model entrypoints
     """
     modules = _module_to_models.keys()
-    return list(sorted(modules))
+    return sorted(modules)
 
 
-def is_model_in_modules(model_name, module_names):
+def is_model_in_modules(
+        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
+) -> bool:
     """Check if a model exists within a subset of modules
+
     Args:
-        model_name (str) - name of model to check
-        module_names (tuple, list, set) - names of modules to search in
+        model_name - name of model to check
+        module_names - names of modules to search in
     """
     arch_name = get_arch_name(model_name)
     assert isinstance(module_names, (tuple, list, set))
     return any(arch_name in _module_to_models[n] for n in module_names)
 
 
-def is_model_pretrained(model_name):
+def is_model_pretrained(model_name: str) -> bool:
     return model_name in _model_has_pretrained
 
 
-def get_pretrained_cfg(model_name, allow_unregistered=True):
+def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
     if model_name in _model_pretrained_cfgs:
         return deepcopy(_model_pretrained_cfgs[model_name])
     arch_name, tag = split_model_name_tag(model_name)
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
     raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
 
 
-def get_pretrained_cfg_value(model_name, cfg_key):
+def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
     """ Get a specific model default_cfg value by key. None if key doesn't exist.
     """
     cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
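
Reviewer note: the annotated signatures above are easiest to sanity-check against a quick usage sketch. The snippet below exercises the listing/filtering side of the registry, assuming this patch is applied to a timm checkout; the concrete model names and wildcard patterns are illustrative only and not part of the change.

# Sketch only: assumes a timm checkout with this patch applied; the model
# names and patterns below are illustrative, not guaranteed to exist.
from timm.models._pretrained import split_model_name_tag
from timm.models._registry import list_models, list_pretrained

# 'arch.tag' names split into architecture and pretrained tag; names without
# a tag fall back to no_tag ('' by default), so the result is Tuple[str, str].
arch, tag = split_model_name_tag('resnet50.a1_in1k')  # ('resnet50', 'a1_in1k')
arch, tag = split_model_name_tag('resnet50')          # ('resnet50', '')

# list_models() now returns List[str]. With pretrained=True, include_tags
# defaults to True, so entries come back as 'arch.tag' names.
arch_names = list_models('resnet*')
tagged_names = list_models('resnet*', pretrained=True)

# list_pretrained() forwards to list_models() with pretrained=True and
# include_tags=True forced on.
tagged_names = list_pretrained('resnet50*')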
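
A similar sketch for the lookup helpers, again assuming this patch and using illustrative model/tag names; PretrainedCfg is the class imported from timm/models/_pretrained.py at the top of _registry.py.

# Sketch only: model_entrypoint() returns the factory registered with
# @register_model; the 'resnet50' names here are illustrative.
from timm.models._registry import (
    is_model,
    model_entrypoint,
    get_pretrained_cfg,
    get_pretrained_cfg_value,
)

if is_model('resnet50'):                      # -> bool
    create_fn = model_entrypoint('resnet50')  # -> Callable[..., Any]
    model = create_fn(pretrained=False)       # entrypoints accept a pretrained kwarg

# Lookups on registered 'arch.tag' names return a deep copy of the cfg; the
# Optional[...] return types cover the unregistered / missing-key cases.
cfg = get_pretrained_cfg('resnet50.a1_in1k')                               # Optional[PretrainedCfg]
num_classes = get_pretrained_cfg_value('resnet50.a1_in1k', 'num_classes')  # Optional[Any]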