Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
parent
4547920f85
commit
0dbd9352ce
@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
""" Bulk Model Script Runner
|
||||
|
||||
Run validation or benchmark script in separate process for each model
|
||||
|
||||
Benchmark all 'vit*' models:
|
||||
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
|
||||
|
||||
Validate all models:
|
||||
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
|
||||
|
||||
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import csv
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
|
||||
from timm.models import is_model, list_models
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Per-model process launcher')
|
||||
|
||||
# model and results args
|
||||
parser.add_argument(
|
||||
'--model-list', metavar='NAME', default='',
|
||||
help='txt file based list of model names to benchmark')
|
||||
parser.add_argument(
|
||||
'--results-file', default='', type=str, metavar='FILENAME',
|
||||
help='Output csv file for validation results (summary)')
|
||||
parser.add_argument(
|
||||
'--sort-key', default='', type=str, metavar='COL',
|
||||
help='Specify sort key for results csv')
|
||||
parser.add_argument(
|
||||
"--pretrained", action='store_true',
|
||||
help="only run models with pretrained weights")
|
||||
|
||||
parser.add_argument(
|
||||
"--delay",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Interval, in seconds, to delay between model invocations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
|
||||
help="Multiprocessing start method to use when creating workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_python",
|
||||
help="Skip prepending the script with 'python' - just execute it directly. Useful "
|
||||
"when the script is not a Python script.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--module",
|
||||
help="Change each process to interpret the launch script as a Python module, executing "
|
||||
"with the same behavior as 'python -m'.",
|
||||
)
|
||||
|
||||
# positional
|
||||
parser.add_argument(
|
||||
"script", type=str,
|
||||
help="Full path to the program/script to be launched for each model config.",
|
||||
)
|
||||
parser.add_argument("script_args", nargs=argparse.REMAINDER)
|
||||
|
||||
|
||||
def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
|
||||
# If ``args`` not passed, defaults to ``sys.argv[:1]``
|
||||
with_python = not args.no_python
|
||||
cmd: Union[Callable, str]
|
||||
cmd_args = []
|
||||
if with_python:
|
||||
cmd = os.getenv("PYTHON_EXEC", sys.executable)
|
||||
cmd_args.append("-u")
|
||||
if args.module:
|
||||
cmd_args.append("-m")
|
||||
cmd_args.append(args.script)
|
||||
else:
|
||||
if args.module:
|
||||
raise ValueError(
|
||||
"Don't use both the '--no_python' flag"
|
||||
" and the '--module' flag at the same time."
|
||||
)
|
||||
cmd = args.script
|
||||
cmd_args.extend(args.script_args)
|
||||
|
||||
return cmd, cmd_args
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
cmd, cmd_args = cmd_from_args(args)
|
||||
|
||||
model_cfgs = []
|
||||
model_names = []
|
||||
if args.model_list == 'all':
|
||||
# NOTE should make this config, for validation / benchmark runs the focus is 1k models,
|
||||
# so we filter out 21/22k and some other unusable heads. This will change in the future...
|
||||
exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
|
||||
model_names = list_models(
|
||||
pretrained=args.pretrained, # only include models w/ pretrained checkpoints if set
|
||||
exclude_filters=exclude_model_filters
|
||||
)
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
elif not is_model(args.model_list):
|
||||
# model name doesn't exist, try as wildcard filter
|
||||
model_names = list_models(args.model_list)
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
|
||||
if not model_cfgs and os.path.exists(args.model_list):
|
||||
with open(args.model_list) as f:
|
||||
model_names = [line.rstrip() for line in f]
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
|
||||
if len(model_cfgs):
|
||||
results_file = args.results_file or './results.csv'
|
||||
results = []
|
||||
errors = []
|
||||
print('Running script on these models: {}'.format(', '.join(model_names)))
|
||||
if not args.sort_key:
|
||||
if 'benchmark' in args.script:
|
||||
if any(['train' in a for a in args.script_args]):
|
||||
sort_key = 'train_samples_per_sec'
|
||||
else:
|
||||
sort_key = 'infer_samples_per_sec'
|
||||
else:
|
||||
sort_key = 'top1'
|
||||
else:
|
||||
sort_key = args.sort_key
|
||||
print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
|
||||
|
||||
try:
|
||||
for m, _ in model_cfgs:
|
||||
if not m:
|
||||
continue
|
||||
args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
|
||||
try:
|
||||
o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
|
||||
r = json.loads(o)
|
||||
results.append(r)
|
||||
except Exception as e:
|
||||
# FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
|
||||
# for further robustness (but more overhead), we may want to manage that by looping here...
|
||||
errors.append(dict(model=m, error=str(e)))
|
||||
if args.delay:
|
||||
time.sleep(args.delay)
|
||||
except KeyboardInterrupt as e:
|
||||
pass
|
||||
|
||||
errors.extend(list(filter(lambda x: 'error' in x, results)))
|
||||
if errors:
|
||||
print(f'{len(errors)} models had errors during run.')
|
||||
for e in errors:
|
||||
print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
|
||||
results = list(filter(lambda x: 'error' not in x, results))
|
||||
|
||||
no_sortkey = list(filter(lambda x: sort_key not in x, results))
|
||||
if no_sortkey:
|
||||
print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
|
||||
else:
|
||||
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
|
||||
|
||||
if len(results):
|
||||
print(f'{len(results)} models run successfully. Saving results to {results_file}.')
|
||||
write_results(results_file, results)
|
||||
|
||||
|
||||
def write_results(results_file, results):
|
||||
with open(results_file, mode='w') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
|
||||
dw.writeheader()
|
||||
for r in results:
|
||||
dw.writerow(r)
|
||||
cf.flush()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,43 @@
|
||||
""" Batch size decay and retry helpers.
|
||||
|
||||
Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
|
||||
|
||||
def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
|
||||
""" power of two batch-size decay with intra steps
|
||||
|
||||
Decay by stepping between powers of 2:
|
||||
* determine power-of-2 floor of current batch size (base batch size)
|
||||
* divide above value by num_intra_steps to determine step size
|
||||
* floor batch_size to nearest multiple of step_size (from base batch size)
|
||||
Examples:
|
||||
num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
|
||||
num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
|
||||
num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
|
||||
num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
|
||||
"""
|
||||
if batch_size <= 1:
|
||||
# return 0 for stopping value so easy to use in loop
|
||||
return 0
|
||||
base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
|
||||
step_size = max(base_batch_size // num_intra_steps, 1)
|
||||
batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
|
||||
if no_odd and batch_size % 2:
|
||||
batch_size -= 1
|
||||
return batch_size
|
||||
|
||||
|
||||
def check_batch_size_retry(error_str):
|
||||
""" check failure error string for conditions where batch decay retry should not be attempted
|
||||
"""
|
||||
error_str = error_str.lower()
|
||||
if 'required rank' in error_str:
|
||||
# Errors involving phrase 'required rank' typically happen when a conv is used that's
|
||||
# not compatible with channels_last memory format.
|
||||
return False
|
||||
if 'illegal' in error_str:
|
||||
# 'Illegal memory access' errors in CUDA typically leave process in unusable state
|
||||
return False
|
||||
return True
|
Loading…
Reference in new issue