|
|
|
@ -66,8 +66,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
|
|
|
|
|
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
|
|
|
|
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
|
|
|
|
list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
|
|
|
|
list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
|
|
|
|
"""
|
|
|
|
|
if module:
|
|
|
|
|
all_models = list(_module_to_models[module])
|
|
|
|
@ -96,6 +96,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
|
|
|
|
|
return list(sorted(models, key=_natural_key))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
|
|
|
|
|
""" Return list of available benchmarks on imagenet, sorted alphabetically
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
filter (str) - Wildcard filter string that works with fnmatch
|
|
|
|
|
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
|
|
|
|
pretrained (bool) - Include only models with pretrained weights if True
|
|
|
|
|
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
|
|
|
|
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet'
|
|
|
|
|
"""
|
|
|
|
|
models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg)
|
|
|
|
|
df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv")
|
|
|
|
|
|
|
|
|
|
return df[df["model"].isin(models)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_model(model_name):
|
|
|
|
|
""" Check if a model name exists
|
|
|
|
|
"""
|
|
|
|
@ -156,4 +175,4 @@ def get_pretrained_cfg_value(model_name, cfg_key):
|
|
|
|
|
"""
|
|
|
|
|
if model_name in _model_pretrained_cfgs:
|
|
|
|
|
return _model_pretrained_cfgs[model_name].get(cfg_key, None)
|
|
|
|
|
return None
|
|
|
|
|
return None
|
|
|
|
|