From 0dbd9352ce7bd7bfd86f64e47a3436312eb6e5dc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 18 Jul 2022 18:01:39 -0700
Subject: [PATCH] Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...

---
 benchmark.py              |  33 ++++---
 bulk_runner.py            | 184 ++++++++++++++++++++++++++++++++++++++
 timm/utils/__init__.py    |   1 +
 timm/utils/decay_batch.py |  43 +++++++++
 validate.py               |  23 +++--
 5 files changed, 259 insertions(+), 25 deletions(-)
 create mode 100755 bulk_runner.py
 create mode 100644 timm/utils/decay_batch.py

diff --git a/benchmark.py b/benchmark.py
index 23047bb5..4679a009 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -21,7 +21,7 @@ import torch.nn.parallel
 from timm.data import resolve_data_config
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

 has_apex = False
 try:
@@ -506,34 +506,31 @@ class ProfileRunner(BenchmarkRunner):
         return results


-def decay_batch_exp(batch_size, factor=0.5, divisor=16):
-    out_batch_size = batch_size * factor
-    if out_batch_size > divisor:
-        out_batch_size = (out_batch_size + 1) // divisor * divisor
-    else:
-        out_batch_size = batch_size - 1
-    return max(0, int(out_batch_size))
-
-
-def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
+def _try_run(
+        model_name,
+        bench_fn,
+        bench_kwargs,
+        initial_batch_size,
+        no_batch_size_retry=False
+):
     batch_size = initial_batch_size
     results = dict()
     error_str = 'Unknown'
-    while batch_size >= 1:
-        torch.cuda.empty_cache()
+    while batch_size:
         try:
+            torch.cuda.empty_cache()
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
             error_str = str(e)
-            if 'channels_last' in error_str:
-                _logger.error(f'{model_name} not supported in channels_last, skipping.')
-                break
             _logger.error(f'"{error_str}" while running benchmark.')
+            if not check_batch_size_retry(error_str):
+                _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
+                break
             if no_batch_size_retry:
                 break
-            batch_size = decay_batch_exp(batch_size)
+            batch_size = decay_batch_step(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     return results

@@ -586,6 +583,8 @@ def benchmark(args):
         if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
+        if 'error' in run_results:
+            break
     if 'error' not in model_results:
         param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
         model_results.setdefault('param_count', param_count)
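
The retry flow added above reduces to a simple pattern: attempt the run, give up on unrecoverable errors, otherwise step the batch size down and try again. A minimal standalone sketch of that pattern (illustrative only; run_fn is a hypothetical callable, the two helpers are the ones added in timm/utils/decay_batch.py further down):

    from timm.utils import decay_batch_step, check_batch_size_retry

    def run_with_batch_retry(run_fn, batch_size):
        # step the batch size down between powers of two until the run succeeds or reaches zero
        while batch_size:
            try:
                return run_fn(batch_size=batch_size)
            except RuntimeError as e:
                if not check_batch_size_retry(str(e)):
                    raise  # unrecoverable (e.g. illegal memory access), don't retry
            batch_size = decay_batch_step(batch_size)
        raise RuntimeError('all batch sizes failed')
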
diff --git a/bulk_runner.py b/bulk_runner.py
new file mode 100755
index 00000000..b71d0bb6
--- /dev/null
+++ b/bulk_runner.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+""" Bulk Model Script Runner
+
+Run validation or benchmark script in separate process for each model
+
+Benchmark all 'vit*' models:
+python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
+
+Validate all models:
+python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
+
+Hacked together by Ross Wightman (https://github.com/rwightman)
+"""
+import argparse
+import os
+import sys
+import csv
+import json
+import subprocess
+import time
+from typing import Callable, List, Tuple, Union
+
+
+from timm.models import is_model, list_models
+
+
+parser = argparse.ArgumentParser(description='Per-model process launcher')
+
+# model and results args
+parser.add_argument(
+    '--model-list', metavar='NAME', default='',
+    help="txt file based list of model names to run, or 'all' / a wildcard filter (e.g. 'vit*')")
+parser.add_argument(
+    '--results-file', default='', type=str, metavar='FILENAME',
+    help='Output csv file for validation results (summary)')
+parser.add_argument(
+    '--sort-key', default='', type=str, metavar='COL',
+    help='Specify sort key for results csv')
+parser.add_argument(
+    "--pretrained", action='store_true',
+    help="only run models with pretrained weights")
+
+parser.add_argument(
+    "--delay",
+    type=float,
+    default=0,
+    help="Interval, in seconds, to delay between model invocations.",
+)
+parser.add_argument(
+    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
+    help="Multiprocessing start method to use when creating workers.",
+)
+parser.add_argument(
+    "--no_python", action='store_true',
+    help="Skip prepending the script with 'python' - just execute it directly. Useful "
+    "when the script is not a Python script.",
+)
+parser.add_argument(
+    "-m",
+    "--module", action='store_true',
+    help="Change each process to interpret the launch script as a Python module, executing "
+    "with the same behavior as 'python -m'.",
+)
+
+# positional
+parser.add_argument(
+    "script", type=str,
+    help="Full path to the program/script to be launched for each model config.",
+)
+parser.add_argument("script_args", nargs=argparse.REMAINDER)
+
+
+def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
+    # build the base launch command + args, '--model <name>' is appended per model in main()
+    with_python = not args.no_python
+    cmd: Union[Callable, str]
+    cmd_args = []
+    if with_python:
+        cmd = os.getenv("PYTHON_EXEC", sys.executable)
+        cmd_args.append("-u")
+        if args.module:
+            cmd_args.append("-m")
+        cmd_args.append(args.script)
+    else:
+        if args.module:
+            raise ValueError(
+                "Don't use both the '--no_python' flag"
+                " and the '--module' flag at the same time."
+            )
+        cmd = args.script
+    cmd_args.extend(args.script_args)
+
+    return cmd, cmd_args
+
+
+def main():
+    args = parser.parse_args()
+    cmd, cmd_args = cmd_from_args(args)
+
+    model_cfgs = []
+    model_names = []
+    if args.model_list == 'all':
+        # NOTE this should be made configurable; for validation / benchmark runs the focus is 1k models,
+        # so we filter out 21/22k and some other unusable heads. This will change in the future...
+        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
+        model_names = list_models(
+            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
+            exclude_filters=exclude_model_filters
+        )
+        model_cfgs = [(n, None) for n in model_names]
+    elif not is_model(args.model_list):
+        # model name doesn't exist, try as wildcard filter
+        model_names = list_models(args.model_list)
+        model_cfgs = [(n, None) for n in model_names]
+
+    if not model_cfgs and os.path.exists(args.model_list):
+        with open(args.model_list) as f:
+            model_names = [line.rstrip() for line in f]
+        model_cfgs = [(n, None) for n in model_names]
+
+    if len(model_cfgs):
+        results_file = args.results_file or './results.csv'
+        results = []
+        errors = []
+        print('Running script on these models: {}'.format(', '.join(model_names)))
+        if not args.sort_key:
+            if 'benchmark' in args.script:
+                if any(['train' in a for a in args.script_args]):
+                    sort_key = 'train_samples_per_sec'
+                else:
+                    sort_key = 'infer_samples_per_sec'
+            else:
+                sort_key = 'top1'
+        else:
+            sort_key = args.sort_key
+        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
+
+        try:
+            for m, _ in model_cfgs:
+                if not m:
+                    continue
+                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
+                try:
+                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
+                    r = json.loads(o)
+                    results.append(r)
+                except Exception as e:
+                    # FIXME batch_size retry loop is currently done in either validate.py or benchmark.py
+                    # for further robustness (but more overhead), we may want to manage that by looping here...
+                    errors.append(dict(model=m, error=str(e)))
+                if args.delay:
+                    time.sleep(args.delay)
+        except KeyboardInterrupt:
+            pass
+
+        errors.extend(list(filter(lambda x: 'error' in x, results)))
+        if errors:
+            print(f'{len(errors)} models had errors during run.')
+            for e in errors:
+                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
+            results = list(filter(lambda x: 'error' not in x, results))
+
+        no_sortkey = list(filter(lambda x: sort_key not in x, results))
+        if no_sortkey:
+            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
+        else:
+            results = sorted(results, key=lambda x: x[sort_key], reverse=True)
+
+        if len(results):
+            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
+            write_results(results_file, results)
+
+
+def write_results(results_file, results):
+    with open(results_file, mode='w') as cf:
+        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
+        dw.writeheader()
+        for r in results:
+            dw.writerow(r)
+        cf.flush()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py
index b8cef321..7b139852 100644
--- a/timm/utils/__init__.py
+++ b/timm/utils/__init__.py
@@ -2,6 +2,7 @@ from .agc import adaptive_clip_grad
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
+from .decay_batch import decay_batch_step, check_batch_size_retry
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
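
With the re-export above, callers can pull both helpers from the package root rather than the submodule, which is how benchmark.py and validate.py import them in this patch:

    from timm.utils import decay_batch_step, check_batch_size_retry
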
diff --git a/timm/utils/decay_batch.py b/timm/utils/decay_batch.py
new file mode 100644
index 00000000..852fa4b8
--- /dev/null
+++ b/timm/utils/decay_batch.py
@@ -0,0 +1,43 @@
+""" Batch size decay and retry helpers.
+
+Copyright 2022 Ross Wightman
+"""
+import math
+
+
+def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
+    """ power of two batch-size decay with intra steps
+
+    Decay by stepping between powers of 2:
+     * determine power-of-2 floor of current batch size (base batch size)
+     * divide above value by num_intra_steps to determine step size
+     * floor batch_size to nearest multiple of step_size (from base batch size)
+    Examples:
+     num_intra_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
+     num_intra_steps == 4 (no_odd=True) --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
+     num_intra_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
+     num_intra_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
+    """
+    if batch_size <= 1:
+        # return 0 as a stopping value so it's easy to use in a loop
+        return 0
+    base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
+    step_size = max(base_batch_size // num_intra_steps, 1)
+    batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
+    if no_odd and batch_size % 2:
+        batch_size -= 1
+    return batch_size
+
+
+def check_batch_size_retry(error_str):
+    """ check failure error string for conditions where batch decay retry should not be attempted
+    """
+    error_str = error_str.lower()
+    if 'required rank' in error_str:
+        # Errors involving phrase 'required rank' typically happen when a conv is used that's
+        # not compatible with channels_last memory format.
+        return False
+    if 'illegal' in error_str:
+        # 'Illegal memory access' errors in CUDA typically leave the process in an unusable state
+        return False
+    return True
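
To make the stepping concrete: repeatedly applying the default decay (num_intra_steps=2) from a batch size of 64 walks the same sequence listed in the docstring, and error strings mentioning an illegal memory access are flagged as not worth retrying. A quick sketch (values shown are what the code above computes):

    from timm.utils import decay_batch_step, check_batch_size_retry

    bs, seq = 64, [64]
    while bs > 1:
        bs = decay_batch_step(bs)  # default num_intra_steps=2
        seq.append(bs)
    # seq == [64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1]

    assert not check_batch_size_retry('CUDA error: an illegal memory access was encountered')
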
diff --git a/validate.py b/validate.py
index 7fa22b49..fd55d408 100755
--- a/validate.py
+++ b/validate.py
@@ -22,7 +22,8 @@ from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
+    decay_batch_step, check_batch_size_retry

 has_apex = False
 try:
@@ -122,6 +123,8 @@ parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                     help='Real labels JSON file for imagenet evaluation')
 parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                     help='Valid label indices txt file for validation of partial label space')
+parser.add_argument('--retry', default=False, action='store_true',
+                    help='Enable batch size decay & retry for single model validation')


 def validate(args):
@@ -303,18 +306,19 @@ def _try_run(args, initial_batch_size):
     batch_size = initial_batch_size
     results = OrderedDict()
     error_str = 'Unknown'
-    while batch_size >= 1:
-        args.batch_size = batch_size
-        torch.cuda.empty_cache()
+    while batch_size:
+        args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
+            torch.cuda.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
             error_str = str(e)
-            if 'channels_last' in error_str:
+            _logger.error(f'"{error_str}" while running validation.')
+            if not check_batch_size_retry(error_str):
                 break
-            _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
-            batch_size = batch_size // 2
+            batch_size = decay_batch_step(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     _logger.error(f'{args.model} failed to validate ({error_str}).')
     return results
@@ -368,7 +372,10 @@ def main():
         if len(results):
             write_results(results_file, results)
     else:
-        results = validate(args)
+        if args.retry:
+            results = _try_run(args, args.batch_size)
+        else:
+            results = validate(args)

     # output results in JSON to stdout w/ delimiter for runner script
     print(f'--result\n{json.dumps(results, indent=4)}')
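
The final validate.py hunk is what ties the scripts to the runner: the child prints a '--result' delimiter followed by a JSON summary, and bulk_runner.py splits the captured stdout on that delimiter before calling json.loads. A minimal sketch of the same contract used standalone (paths and model name are illustrative only):

    import json
    import subprocess

    out = subprocess.check_output([
        'python', '-u', 'validate.py', '/imagenet/validation/',
        '--amp', '-b', '512', '--retry', '--model', 'resnet50',
    ]).decode('utf-8')
    results = json.loads(out.split('--result')[-1])  # text after the delimiter is the JSON summary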