Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...

pull/1363/head
Ross Wightman 2 years ago
parent 4547920f85
commit 0dbd9352ce

@ -21,7 +21,7 @@ import torch.nn.parallel
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2 from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
has_apex = False has_apex = False
try: try:
@ -506,34 +506,31 @@ class ProfileRunner(BenchmarkRunner):
return results return results
def decay_batch_exp(batch_size, factor=0.5, divisor=16): def _try_run(
out_batch_size = batch_size * factor model_name,
if out_batch_size > divisor: bench_fn,
out_batch_size = (out_batch_size + 1) // divisor * divisor bench_kwargs,
else: initial_batch_size,
out_batch_size = batch_size - 1 no_batch_size_retry=False
return max(0, int(out_batch_size)) ):
def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
batch_size = initial_batch_size batch_size = initial_batch_size
results = dict() results = dict()
error_str = 'Unknown' error_str = 'Unknown'
while batch_size >= 1: while batch_size:
torch.cuda.empty_cache()
try: try:
torch.cuda.empty_cache()
bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs) bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
results = bench.run() results = bench.run()
return results return results
except RuntimeError as e: except RuntimeError as e:
error_str = str(e) error_str = str(e)
if 'channels_last' in error_str:
_logger.error(f'{model_name} not supported in channels_last, skipping.')
break
_logger.error(f'"{error_str}" while running benchmark.') _logger.error(f'"{error_str}" while running benchmark.')
if not check_batch_size_retry(error_str):
_logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
break
if no_batch_size_retry: if no_batch_size_retry:
break break
batch_size = decay_batch_exp(batch_size) batch_size = decay_batch_step(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.') _logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str results['error'] = error_str
return results return results
@ -586,6 +583,8 @@ def benchmark(args):
if prefix and 'error' not in run_results: if prefix and 'error' not in run_results:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results) model_results.update(run_results)
if 'error' in run_results:
break
if 'error' not in model_results: if 'error' not in model_results:
param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0)) param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
model_results.setdefault('param_count', param_count) model_results.setdefault('param_count', param_count)

@ -0,0 +1,184 @@
#!/usr/bin/env python3
""" Bulk Model Script Runner
Run validation or benchmark script in separate process for each model
Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
Validate all models:
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import sys
import csv
import json
import subprocess
import time
from typing import Callable, List, Tuple, Union
from timm.models import is_model, list_models
parser = argparse.ArgumentParser(description='Per-model process launcher')
# model and results args
parser.add_argument(
'--model-list', metavar='NAME', default='',
help='txt file based list of model names to benchmark')
parser.add_argument(
'--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument(
'--sort-key', default='', type=str, metavar='COL',
help='Specify sort key for results csv')
parser.add_argument(
"--pretrained", action='store_true',
help="only run models with pretrained weights")
parser.add_argument(
"--delay",
type=float,
default=0,
help="Interval, in seconds, to delay between model invocations.",
)
parser.add_argument(
"--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method to use when creating workers.",
)
parser.add_argument(
"--no_python",
help="Skip prepending the script with 'python' - just execute it directly. Useful "
"when the script is not a Python script.",
)
parser.add_argument(
"-m",
"--module",
help="Change each process to interpret the launch script as a Python module, executing "
"with the same behavior as 'python -m'.",
)
# positional
parser.add_argument(
"script", type=str,
help="Full path to the program/script to be launched for each model config.",
)
parser.add_argument("script_args", nargs=argparse.REMAINDER)
def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
with_python = not args.no_python
cmd: Union[Callable, str]
cmd_args = []
if with_python:
cmd = os.getenv("PYTHON_EXEC", sys.executable)
cmd_args.append("-u")
if args.module:
cmd_args.append("-m")
cmd_args.append(args.script)
else:
if args.module:
raise ValueError(
"Don't use both the '--no_python' flag"
" and the '--module' flag at the same time."
)
cmd = args.script
cmd_args.extend(args.script_args)
return cmd, cmd_args
def main():
args = parser.parse_args()
cmd, cmd_args = cmd_from_args(args)
model_cfgs = []
model_names = []
if args.model_list == 'all':
# NOTE should make this config, for validation / benchmark runs the focus is 1k models,
# so we filter out 21/22k and some other unusable heads. This will change in the future...
exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
model_names = list_models(
pretrained=args.pretrained, # only include models w/ pretrained checkpoints if set
exclude_filters=exclude_model_filters
)
model_cfgs = [(n, None) for n in model_names]
elif not is_model(args.model_list):
# model name doesn't exist, try as wildcard filter
model_names = list_models(args.model_list)
model_cfgs = [(n, None) for n in model_names]
if not model_cfgs and os.path.exists(args.model_list):
with open(args.model_list) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names]
if len(model_cfgs):
results_file = args.results_file or './results.csv'
results = []
errors = []
print('Running script on these models: {}'.format(', '.join(model_names)))
if not args.sort_key:
if 'benchmark' in args.script:
if any(['train' in a for a in args.script_args]):
sort_key = 'train_samples_per_sec'
else:
sort_key = 'infer_samples_per_sec'
else:
sort_key = 'top1'
else:
sort_key = args.sort_key
print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
try:
for m, _ in model_cfgs:
if not m:
continue
args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
try:
o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
r = json.loads(o)
results.append(r)
except Exception as e:
# FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
# for further robustness (but more overhead), we may want to manage that by looping here...
errors.append(dict(model=m, error=str(e)))
if args.delay:
time.sleep(args.delay)
except KeyboardInterrupt as e:
pass
errors.extend(list(filter(lambda x: 'error' in x, results)))
if errors:
print(f'{len(errors)} models had errors during run.')
for e in errors:
print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
results = list(filter(lambda x: 'error' not in x, results))
no_sortkey = list(filter(lambda x: sort_key not in x, results))
if no_sortkey:
print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
else:
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results):
print(f'{len(results)} models run successfully. Saving results to {results_file}.')
write_results(results_file, results)
def write_results(results_file, results):
with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()

@ -2,6 +2,7 @@ from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy, set_jit_fuser from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo from .log import setup_default_logging, FormatterNoInfo

@ -0,0 +1,43 @@
""" Batch size decay and retry helpers.
Copyright 2022 Ross Wightman
"""
import math
def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
""" power of two batch-size decay with intra steps
Decay by stepping between powers of 2:
* determine power-of-2 floor of current batch size (base batch size)
* divide above value by num_intra_steps to determine step size
* floor batch_size to nearest multiple of step_size (from base batch size)
Examples:
num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
"""
if batch_size <= 1:
# return 0 for stopping value so easy to use in loop
return 0
base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
step_size = max(base_batch_size // num_intra_steps, 1)
batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
if no_odd and batch_size % 2:
batch_size -= 1
return batch_size
def check_batch_size_retry(error_str):
""" check failure error string for conditions where batch decay retry should not be attempted
"""
error_str = error_str.lower()
if 'required rank' in error_str:
# Errors involving phrase 'required rank' typically happen when a conv is used that's
# not compatible with channels_last memory format.
return False
if 'illegal' in error_str:
# 'Illegal memory access' errors in CUDA typically leave process in unusable state
return False
return True

@ -22,7 +22,8 @@ from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
decay_batch_step, check_batch_size_retry
has_apex = False has_apex = False
try: try:
@ -122,6 +123,8 @@ parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation') help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space') help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
help='Enable batch size decay & retry for single model validation')
def validate(args): def validate(args):
@ -303,18 +306,19 @@ def _try_run(args, initial_batch_size):
batch_size = initial_batch_size batch_size = initial_batch_size
results = OrderedDict() results = OrderedDict()
error_str = 'Unknown' error_str = 'Unknown'
while batch_size >= 1: while batch_size:
args.batch_size = batch_size args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
torch.cuda.empty_cache()
try: try:
torch.cuda.empty_cache()
results = validate(args) results = validate(args)
return results return results
except RuntimeError as e: except RuntimeError as e:
error_str = str(e) error_str = str(e)
if 'channels_last' in error_str: _logger.error(f'"{error_str}" while running validation.')
if not check_batch_size_retry(error_str):
break break
_logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.') batch_size = decay_batch_step(batch_size)
batch_size = batch_size // 2 _logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str results['error'] = error_str
_logger.error(f'{args.model} failed to validate ({error_str}).') _logger.error(f'{args.model} failed to validate ({error_str}).')
return results return results
@ -368,7 +372,10 @@ def main():
if len(results): if len(results):
write_results(results_file, results) write_results(results_file, results)
else: else:
results = validate(args) if args.retry:
results = _try_run(args, args.batch_size)
else:
results = validate(args)
# output results in JSON to stdout w/ delimiter for runner script # output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}') print(f'--result\n{json.dumps(results, indent=4)}')

Loading…
Cancel
Save