diff --git a/requirements.txt b/requirements.txt index 2d29a27c..93884552 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch>=1.4.0 torchvision>=0.5.0 pyyaml +pandas diff --git a/timm/models/registry.py b/timm/models/registry.py index 9f58060f..8a1a8bdb 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -7,6 +7,7 @@ import re import fnmatch from collections import defaultdict from copy import deepcopy +import pandas as pd __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] @@ -66,8 +67,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: - model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' - model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module """ if module: all_models = list(_module_to_models[module]) @@ -96,6 +97,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name return list(sorted(models, key=_natural_key)) +def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): + """ Return list of available benchmarks on imagenet, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + + Example: + list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet' + """ + models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg) + df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv") + + return df[df["model"].isin(models)] + + def is_model(model_name): """ Check if a model name exists """ @@ -156,4 +176,4 @@ def get_pretrained_cfg_value(model_name, cfg_key): """ if model_name in _model_pretrained_cfgs: return _model_pretrained_cfgs[model_name].get(cfg_key, None) - return None \ No newline at end of file + return None