You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
6.4 KiB
185 lines
6.4 KiB
2 years ago
|
#!/usr/bin/env python3
""" Bulk Model Script Runner

Run validation or benchmark script in separate process for each model

Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512

Validate all models:
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
|
||
|
import argparse
|
||
|
import os
|
||
|
import sys
|
||
|
import csv
|
||
|
import json
|
||
|
import subprocess
|
||
|
import time
|
||
|
from typing import Callable, List, Tuple, Union
|
||
|
|
||
|
|
||
|
from timm.models import is_model, list_models
|
||
|
|
||
|
|
||
|
# Command-line interface of the launcher. Options must appear before the script
# path: everything following the `script` positional is collected verbatim into
# `script_args` (argparse.REMAINDER).
parser = argparse.ArgumentParser(description='Per-model process launcher')

# model and results args
parser.add_argument(
    '--model-list', metavar='NAME', default='',
    help="Models to run: 'all', a wildcard filter (e.g. 'vit*'), "
         "or a txt file with one model name per line")
parser.add_argument(
    '--results-file', default='', type=str, metavar='FILENAME',
    help='Output csv file for validation results (summary)')
parser.add_argument(
    '--sort-key', default='', type=str, metavar='COL',
    help='Specify sort key for results csv')
parser.add_argument(
    "--pretrained", action='store_true',
    help="only run models with pretrained weights")

parser.add_argument(
    "--delay",
    type=float,
    default=0,
    help="Interval, in seconds, to delay between model invocations.",
)
parser.add_argument(
    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
    help="Multiprocessing start method to use when creating workers.",
)
# NOTE: `--no_python` and `--module` are boolean switches. Without
# action='store_true' argparse expects a value and would swallow the next
# token (typically the script path) as that value.
parser.add_argument(
    "--no_python",
    action='store_true',
    help="Skip prepending the script with 'python' - just execute it directly. Useful "
         "when the script is not a Python script.",
)
parser.add_argument(
    "-m",
    "--module",
    action='store_true',
    help="Change each process to interpret the launch script as a Python module, executing "
         "with the same behavior as 'python -m'.",
)

# positional
parser.add_argument(
    "script", type=str,
    help="Full path to the program/script to be launched for each model config.",
)
parser.add_argument("script_args", nargs=argparse.REMAINDER)
|
||
|
|
||
|
|
||
|
def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
    """Build the executable and base argument list used to launch the script.

    Args:
        args: Parsed launcher arguments. Reads ``no_python``, ``module``,
            ``script`` and ``script_args``.

    Returns:
        A ``(cmd, cmd_args)`` pair. With python (the default) ``cmd`` is the
        interpreter taken from the ``PYTHON_EXEC`` environment variable, falling
        back to ``sys.executable``, and ``cmd_args`` is ``-u``, optionally
        ``-m``, the script, then the script's own arguments. With
        ``no_python`` set, ``cmd`` is the script itself and ``cmd_args`` holds
        only the script's arguments.

    Raises:
        ValueError: If both ``no_python`` and ``module`` are set.
    """
    # The two flags are mutually exclusive: '-m' only means something to python.
    if args.no_python and args.module:
        raise ValueError(
            "Don't use both the '--no_python' flag"
            " and the '--module' flag at the same time."
        )

    cmd: Union[Callable, str]
    cmd_args = []
    if args.no_python:
        # Execute the script directly, e.g. a shell script or binary.
        cmd = args.script
    else:
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
        cmd_args.append("-u")  # python's unbuffered stdout/stderr switch
        if args.module:
            cmd_args.append("-m")
        cmd_args.append(args.script)

    cmd_args.extend(args.script_args)
    return cmd, cmd_args
|
||
|
|
||
|
|
||
|
def main():
    """Run the target script once per selected model and save the collected results.

    Model selection from ``--model-list``:
      * ``'all'``: every model returned by ``list_models`` minus the excluded
        21k/22k/dino variants (restricted to pretrained ones if ``--pretrained``).
      * an exact model name: just that model.
      * anything else: treated as a wildcard filter for ``list_models``
        (restricted to pretrained ones if ``--pretrained``).
      * if none of the above selected anything and the value is an existing
        path, it is read as a text file with one model name per line.

    Each model is run in its own subprocess with ``--model <name>`` appended.
    The text following the last ``--result`` marker in the subprocess stdout is
    parsed as JSON. Failures are collected and reported instead of aborting the
    run; Ctrl-C stops the loop but still reports and saves what was gathered.
    """
    args = parser.parse_args()
    cmd, cmd_args = cmd_from_args(args)

    model_cfgs = []
    model_names = []
    if args.model_list == 'all':
        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
        # so we filter out 21/22k and some other unusable heads. This will change in the future...
        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
            exclude_filters=exclude_model_filters
        )
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list, pretrained=args.pretrained)
        model_cfgs = [(n, None) for n in model_names]
    else:
        # exact model name: run just that one instead of silently selecting nothing
        model_names = [args.model_list]
        model_cfgs = [(args.model_list, None)]

    if not model_cfgs and os.path.exists(args.model_list):
        # fall back to a text file listing one model name per line
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]

    if not model_cfgs:
        print(f"No models found for --model-list '{args.model_list}', nothing to run.")
        return

    results_file = args.results_file or './results.csv'
    results = []
    errors = []
    print('Running script on these models: {}'.format(', '.join(model_names)))

    # Pick a default sort column from the kind of script being run.
    if args.sort_key:
        sort_key = args.sort_key
    elif 'benchmark' in args.script:
        if any('train' in a for a in args.script_args):
            sort_key = 'train_samples_per_sec'
        else:
            sort_key = 'infer_samples_per_sec'
    else:
        sort_key = 'top1'
    print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')

    try:
        for m, _ in model_cfgs:
            if not m:
                # blank line from a model list file
                continue
            args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
            try:
                o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                r = json.loads(o)
                results.append(r)
            except Exception as e:
                # Deliberately broad: one failing model (non-zero exit, unparseable
                # output, ...) must not abort the whole bulk run.
                # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
                # for further robustness (but more overhead), we may want to manage that by looping here...
                errors.append(dict(model=m, error=str(e)))
            if args.delay:
                time.sleep(args.delay)
    except KeyboardInterrupt:
        # Stop launching further models but still report/save what we have.
        pass

    # Scripts may report a failure themselves via an 'error' entry in their result.
    errors.extend(r for r in results if 'error' in r)
    if errors:
        print(f'{len(errors)} models had errors during run.')
        for e in errors:
            print(f"\t {e.get('model', 'Unknown')} ({e.get('error', 'Unknown')})")
    results = [r for r in results if 'error' not in r]

    no_sortkey = [r for r in results if sort_key not in r]
    if no_sortkey:
        print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
    else:
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)

    if results:
        print(f'{len(results)} models run successfully. Saving results to {results_file}.')
        write_results(results_file, results)
|
||
|
|
||
|
|
||
|
def write_results(results_file, results):
    """Write a list of result dicts to ``results_file`` as csv.

    Args:
        results_file: Path of the csv file to create or overwrite.
        results: List of dicts, one per row. The header is the union of all
            keys in first-seen order; keys missing from a row are written as
            empty cells.

    Nothing is written when ``results`` is empty.
    """
    if not results:
        return

    # Union of keys across all rows (order preserved). Using only the first
    # row's keys makes DictWriter raise ValueError on any row with extra keys.
    fieldnames = list(dict.fromkeys(k for r in results for k in r))

    # newline='' is required by the csv module to avoid spurious blank lines
    # on platforms that translate line endings.
    with open(results_file, mode='w', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=fieldnames)
        dw.writeheader()
        dw.writerows(results)
|
||
|
|
||
|
|
||
|
# Run the launcher only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|