@ -8,7 +8,7 @@ import sys
from collections import defaultdict , deque
from copy import deepcopy
from dataclasses import replace
from typing import List, Optional , Union , Tuple
from typing import Any, Callable , Dict , Iterable , List, Optional , Set , Sequence , Union , Tuple
from . _pretrained import PretrainedCfg , DefaultCfg , split_model_name_tag
@ -16,20 +16,20 @@ __all__ = [
' list_models ' , ' list_pretrained ' , ' is_model ' , ' model_entrypoint ' , ' list_modules ' , ' is_model_in_modules ' ,
' get_pretrained_cfg_value ' , ' is_model_pretrained ' , ' get_arch_name ' ]
_module_to_models = defaultdict ( set ) # dict of sets to check membership of model in module
_model_to_module = { } # mapping of model names to module names
_model_entrypoints = { } # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set ( ) # set of model names that have pretrained weight url present
_model_default_cfgs = dict ( ) # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict ( ) # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict ( list ) # shortcut to map each model arch to all model + tag names
_module_to_models : Dict [ str , Set [ str ] ] = defaultdict ( set ) # dict of sets to check membership of model in module
_model_to_module : Dict [ str , str ] = { } # mapping of model names to module names
_model_entrypoints : Dict [ str , Callable [ . . . , Any ] ] = { } # mapping of model names to architecture entrypoint fns
_model_has_pretrained : Set [ str ] = set ( ) # set of model names that have pretrained weight url present
_model_default_cfgs : Dict [ str , PretrainedCfg ] = { } # central repo for model arch -> default cfg objects
_model_pretrained_cfgs : Dict [ str , PretrainedCfg ] = { } # central repo for model arch.tag -> pretrained cfgs
_model_with_tags : Dict [ str , List [ str ] ] = defaultdict ( list ) # shortcut to map each model arch to all model + tag names
def get_arch_name ( model_name : str ) - > Tuple [ str , Optional [ str ] ] :
def get_arch_name ( model_name : str ) - > str :
return split_model_name_tag ( model_name ) [ 0 ]
def register_model ( fn ):
def register_model ( fn : Callable [ . . . , Any ] ) - > Callable [ . . . , Any ] :
# lookup containing module
mod = sys . modules [ fn . __module__ ]
module_name_split = fn . __module__ . split ( ' . ' )
@ -40,7 +40,7 @@ def register_model(fn):
if hasattr ( mod , ' __all__ ' ) :
mod . __all__ . append ( model_name )
else :
mod . __all__ = [ model_name ]
mod . __all__ = [ model_name ] # type: ignore
# add entries to registry dict/sets
_model_entrypoints [ model_name ] = fn
@ -87,28 +87,33 @@ def register_model(fn):
return fn
def _natural_key ( string_ ) :
def _natural_key ( string_ : str ) - > List [ Union [ int , str ] ] :
""" See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/ """
return [ int ( s ) if s . isdigit ( ) else s for s in re . split ( r ' ( \ d+) ' , string_ . lower ( ) ) ]
def list_models (
filter : Union [ str , List [ str ] ] = ' ' ,
module : str = ' ' ,
pretrained = False ,
exclude_filters : str = ' ' ,
pretrained : bool = False ,
exclude_filters : Union [ str , List [ str ] ] = ' ' ,
name_matches_cfg : bool = False ,
include_tags : Optional [ bool ] = None ,
) :
) - > List [ str ] :
""" Return list of available model names, sorted alphabetically
Args :
filter ( str ) - Wildcard filter string that works with fnmatch
module ( str ) - Limit model selection to a specific submodule ( ie ' vision_transformer ' )
pretrained ( bool ) - Include only models with valid pretrained weights if True
exclude_filters ( str or list [ str ] ) - Wildcard filters to exclude models after including them with filter
name_matches_cfg ( bool ) - Include only models w / model_name matching default_cfg name ( excludes some aliases )
include_tags ( Optional [ boo ] ) - Include pretrained tags in model names ( model . tag ) . If None , defaults
filter - Wildcard filter string that works with fnmatch
module - Limit model selection to a specific submodule ( ie ' vision_transformer ' )
pretrained - Include only models with valid pretrained weights if True
exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg - Include only models w / model_name matching default_cfg name ( excludes some aliases )
include_tags - Include pretrained tags in model names ( model . tag ) . If None , defaults
set to True when pretrained = True else False ( default : None )
Returns :
models - The sorted list of models
Example :
model_list ( ' gluon_resnet* ' ) - - returns all models starting with ' gluon_resnet '
model_list ( ' *resnext*, ' resnet ' ) -- returns all models with ' resnext ' in ' resnet ' module
@ -118,7 +123,7 @@ def list_models(
include_tags = pretrained
if module :
all_models = list ( _module_to_models [ module ] )
all_models : Iterable [ str ] = list ( _module_to_models [ module ] )
else :
all_models = _model_entrypoints . keys ( )
@ -130,14 +135,14 @@ def list_models(
all_models = models_with_tags
if filter :
models = [ ]
models : Set [ str ] = set ( )
include_filters = filter if isinstance ( filter , ( tuple , list ) ) else [ filter ]
for f in include_filters :
include_models = fnmatch . filter ( all_models , f ) # include these models
if len ( include_models ) :
models = set ( models ) . union ( include_models )
models = models . union ( include_models )
else :
models = all_models
models = set ( all_models )
if exclude_filters :
if not isinstance ( exclude_filters , ( tuple , list ) ) :
@ -145,7 +150,7 @@ def list_models(
for xf in exclude_filters :
exclude_models = fnmatch . filter ( models , xf ) # exclude these models
if len ( exclude_models ) :
models = set ( models ) . difference ( exclude_models )
models = models . difference ( exclude_models )
if pretrained :
models = _model_has_pretrained . intersection ( models )
@ -153,13 +158,13 @@ def list_models(
if name_matches_cfg :
models = set ( _model_pretrained_cfgs ) . intersection ( models )
return list( sorted( models , key = _natural_key ) )
return sorted( models , key = _natural_key )
def list_pretrained (
filter : Union [ str , List [ str ] ] = ' ' ,
exclude_filters : str = ' ' ,
) :
) - > List [ str ] :
return list_models (
filter = filter ,
pretrained = True ,
@ -168,14 +173,14 @@ def list_pretrained(
)
def is_model ( model_name ):
def is_model ( model_name : str ) - > bool :
""" Check if a model name exists
"""
arch_name = get_arch_name ( model_name )
return arch_name in _model_entrypoints
def model_entrypoint ( model_name , module_filter : Optional [ str ] = None ) :
def model_entrypoint ( model_name : str , module_filter : Optional [ str ] = None ) - > Callable [ . . . , Any ] :
""" Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name ( model_name )
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
return _model_entrypoints [ arch_name ]
def list_modules ( ) :
def list_modules ( ) - > List [ str ] :
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models . keys ( )
return list( sorted( modules ) )
return sorted( modules )
def is_model_in_modules ( model_name , module_names ) :
def is_model_in_modules (
model_name : str , module_names : Union [ Tuple [ str , . . . ] , List [ str ] , Set [ str ] ]
) - > bool :
""" Check if a model exists within a subset of modules
Args :
model_name ( str ) - name of model to check
module_names ( tuple , list , set ) - names of modules to search in
model_name - name of model to check
module_names - names of modules to search in
"""
arch_name = get_arch_name ( model_name )
assert isinstance ( module_names , ( tuple , list , set ) )
return any ( arch_name in _module_to_models [ n ] for n in module_names )
def is_model_pretrained ( model_name ):
def is_model_pretrained ( model_name : str ) - > bool :
return model_name in _model_has_pretrained
def get_pretrained_cfg ( model_name , allow_unregistered = True ) :
def get_pretrained_cfg ( model_name : str , allow_unregistered : bool = True ) - > Optional [ PretrainedCfg ] :
if model_name in _model_pretrained_cfgs :
return deepcopy ( _model_pretrained_cfgs [ model_name ] )
arch_name , tag = split_model_name_tag ( model_name )
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
raise RuntimeError ( f ' Model architecture ( { arch_name } ) has no pretrained cfg registered. ' )
def get_pretrained_cfg_value ( model_name , cfg_key ):
def get_pretrained_cfg_value ( model_name : str , cfg_key : str ) - > Optional [ Any ] :
""" Get a specific model default_cfg value by key. None if key doesn ' t exist.
"""
cfg = get_pretrained_cfg ( model_name , allow_unregistered = False )