Add `list_benchmarks` to compare models easily

This is a simple function for listing ImageNet benchmarks found [here](https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv) for users to use
pull/1253/head
Mohamed Rashad 3 years ago committed by GitHub
parent 6d4665bb52
commit 23aeb88b57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,8 +66,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
"""
if module:
all_models = list(_module_to_models[module])
@ -96,6 +96,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
return list(sorted(models, key=_natural_key))
def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available benchmarks on imagenet, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example:
list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet'
"""
models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg)
df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv")
return df[df["model"].isin(models)]
def is_model(model_name):
""" Check if a model name exists
"""
@ -156,4 +175,4 @@ def get_pretrained_cfg_value(model_name, cfg_key):
"""
if model_name in _model_pretrained_cfgs:
return _model_pretrained_cfgs[model_name].get(cfg_key, None)
return None
return None

Loading…
Cancel
Save