Add type annotations to _registry.py

Description

Add type annotations to _registry.py so that they will pass mypy
--strict.

Comment

I was reading the code and felt that this module would be easier to
understand with type annotations. Therefore, I went ahead and added the
annotations.

The idea with this PR is to start small to see if we can align on _how_
to annotate types. I've seen people in the past disagree on how strictly
to annotate the code base, so before spending too much time on this, I
wanted to check if you agree, Ross.

Most of the added types should be straightforward. Some notes on the
non-trivial changes:

- I made no assumption about the fn passed to register_model, but maybe
  the type could be stricter. Are all models nn.Modules?
- If I'm not mistaken, the type hint for get_arch_name was incorrect
- I had to add a # type: ignore to model.__all__ = ...
- I made some minor code changes to list_models to facilitate the
  typing. I think the changes should not affect the logic of the function.
- I removed list from list(sorted(...)) because sorted returns always a
  list.
main
Benjamin Bossan 1 year ago committed by Ross Wightman
parent c9406ce608
commit a5b01ec04e

@ -93,7 +93,7 @@ class DefaultCfg:
return tag, self.cfgs[tag] return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag=''): def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1) model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag tag = tag_list[0] if tag_list else no_tag
return model_name, tag return model_name, tag

@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque from collections import defaultdict, deque
from copy import deepcopy from copy import deepcopy
from dataclasses import replace from dataclasses import replace
from typing import List, Optional, Union, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -16,20 +16,20 @@ __all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module _module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names _model_to_module: Dict[str, str] = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns _model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present _model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects _model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs _model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names _model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0] return split_model_name_tag(model_name)[0]
def register_model(fn): def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module # lookup containing module
mod = sys.modules[fn.__module__] mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.') module_name_split = fn.__module__.split('.')
@ -40,7 +40,7 @@ def register_model(fn):
if hasattr(mod, '__all__'): if hasattr(mod, '__all__'):
mod.__all__.append(model_name) mod.__all__.append(model_name)
else: else:
mod.__all__ = [model_name] mod.__all__ = [model_name] # type: ignore
# add entries to registry dict/sets # add entries to registry dict/sets
_model_entrypoints[model_name] = fn _model_entrypoints[model_name] = fn
@ -87,28 +87,33 @@ def register_model(fn):
return fn return fn
def _natural_key(string_): def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models( def list_models(
filter: Union[str, List[str]] = '', filter: Union[str, List[str]] = '',
module: str = '', module: str = '',
pretrained=False, pretrained: bool = False,
exclude_filters: str = '', exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False, name_matches_cfg: bool = False,
include_tags: Optional[bool] = None, include_tags: Optional[bool] = None,
): ) -> List[str]:
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
filter (str) - Wildcard filter string that works with fnmatch filter - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') module - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True pretrained - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults include_tags - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None) set to True when pretrained=True else False (default: None)
Returns:
models - The sorted list of models
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@ -118,7 +123,7 @@ def list_models(
include_tags = pretrained include_tags = pretrained
if module: if module:
all_models = list(_module_to_models[module]) all_models: Iterable[str] = list(_module_to_models[module])
else: else:
all_models = _model_entrypoints.keys() all_models = _model_entrypoints.keys()
@ -130,14 +135,14 @@ def list_models(
all_models = models_with_tags all_models = models_with_tags
if filter: if filter:
models = [] models: Set[str] = set()
include_filters = filter if isinstance(filter, (tuple, list)) else [filter] include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters: for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models): if len(include_models):
models = set(models).union(include_models) models = models.union(include_models)
else: else:
models = all_models models = set(all_models)
if exclude_filters: if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)): if not isinstance(exclude_filters, (tuple, list)):
@ -145,7 +150,7 @@ def list_models(
for xf in exclude_filters: for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models): if len(exclude_models):
models = set(models).difference(exclude_models) models = models.difference(exclude_models)
if pretrained: if pretrained:
models = _model_has_pretrained.intersection(models) models = _model_has_pretrained.intersection(models)
@ -153,13 +158,13 @@ def list_models(
if name_matches_cfg: if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models) models = set(_model_pretrained_cfgs).intersection(models)
return list(sorted(models, key=_natural_key)) return sorted(models, key=_natural_key)
def list_pretrained( def list_pretrained(
filter: Union[str, List[str]] = '', filter: Union[str, List[str]] = '',
exclude_filters: str = '', exclude_filters: str = '',
): ) -> List[str]:
return list_models( return list_models(
filter=filter, filter=filter,
pretrained=True, pretrained=True,
@ -168,14 +173,14 @@ def list_pretrained(
) )
def is_model(model_name): def is_model(model_name: str) -> bool:
""" Check if a model name exists """ Check if a model name exists
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints return arch_name in _model_entrypoints
def model_entrypoint(model_name, module_filter: Optional[str] = None): def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
"""Fetch a model entrypoint for specified model name """Fetch a model entrypoint for specified model name
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
return _model_entrypoints[arch_name] return _model_entrypoints[arch_name]
def list_modules(): def list_modules() -> List[str]:
""" Return list of module names that contain models / model entrypoints """ Return list of module names that contain models / model entrypoints
""" """
modules = _module_to_models.keys() modules = _module_to_models.keys()
return list(sorted(modules)) return sorted(modules)
def is_model_in_modules(model_name, module_names): def is_model_in_modules(
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
"""Check if a model exists within a subset of modules """Check if a model exists within a subset of modules
Args: Args:
model_name (str) - name of model to check model_name - name of model to check
module_names (tuple, list, set) - names of modules to search in module_names - names of modules to search in
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set)) assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names) return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name): def is_model_pretrained(model_name: str) -> bool:
return model_name in _model_has_pretrained return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name, allow_unregistered=True): def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
if model_name in _model_pretrained_cfgs: if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name]) return deepcopy(_model_pretrained_cfgs[model_name])
arch_name, tag = split_model_name_tag(model_name) arch_name, tag = split_model_name_tag(model_name)
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.') raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name, cfg_key): def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
""" Get a specific model default_cfg value by key. None if key doesn't exist. """ Get a specific model default_cfg value by key. None if key doesn't exist.
""" """
cfg = get_pretrained_cfg(model_name, allow_unregistered=False) cfg = get_pretrained_cfg(model_name, allow_unregistered=False)

Loading…
Cancel
Save