From 23aeb88b57adeb55cb018e2f688fa4de1a17a843 Mon Sep 17 00:00:00 2001 From: Mohamed Rashad Date: Fri, 6 May 2022 23:03:07 +0200 Subject: [PATCH] Add `list_benchmarks` to compare models easily This is a simple function for listing ImageNet benchmarks found [here](https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv) for users to use --- timm/models/registry.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/timm/models/registry.py b/timm/models/registry.py index 9f58060f..91746db7 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -66,8 +66,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: - model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' - model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module """ if module: all_models = list(_module_to_models[module]) @@ -96,6 +96,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name return list(sorted(models, key=_natural_key)) +def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): + """ Return list of available benchmarks on imagenet, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + + Example: + list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet' + """ + models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg) + df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv") + + return df[df["model"].isin(models)] + + def is_model(model_name): """ Check if a model name exists """ @@ -156,4 +175,4 @@ def get_pretrained_cfg_value(model_name, cfg_key): """ if model_name in _model_pretrained_cfgs: return _model_pretrained_cfgs[model_name].get(cfg_key, None) - return None \ No newline at end of file + return None