pull/1253/merge
Mohamed Rashad 3 years ago committed by GitHub
commit 744e010a2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
torch>=1.4.0 torch>=1.4.0
torchvision>=0.5.0 torchvision>=0.5.0
pyyaml pyyaml
pandas

@ -7,6 +7,7 @@ import re
import fnmatch import fnmatch
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
import pandas as pd
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
@ -66,8 +67,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
""" """
if module: if module:
all_models = list(_module_to_models[module]) all_models = list(_module_to_models[module])
@ -96,6 +97,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
return list(sorted(models, key=_natural_key)) return list(sorted(models, key=_natural_key))
def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available benchmarks on imagenet, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example:
list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet'
"""
models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg)
df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv")
return df[df["model"].isin(models)]
def is_model(model_name): def is_model(model_name):
""" Check if a model name exists """ Check if a model name exists
""" """
@ -156,4 +176,4 @@ def get_pretrained_cfg_value(model_name, cfg_key):
""" """
if model_name in _model_pretrained_cfgs: if model_name in _model_pretrained_cfgs:
return _model_pretrained_cfgs[model_name].get(cfg_key, None) return _model_pretrained_cfgs[model_name].get(cfg_key, None)
return None return None

Loading…
Cancel
Save