diff --git a/benchmark.py b/benchmark.py
index 1362eeab..74f09489 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -71,6 +71,8 @@ parser.add_argument('--bench', default='both', type=str,
                     help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
+parser.add_argument('--no-retry', action='store_true', default=False,
+                    help='Do not decay batch size and retry on error.')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--num-warm-iter', default=10, type=int,
@@ -169,10 +171,9 @@ def resolve_precision(precision: str):
 
 
 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
-    macs, _ = get_model_profile(
+    _, macs, _ = get_model_profile(
         model=model,
-        input_res=(batch_size,) + input_size,  # input shape or input to the input_constructor
-        input_constructor=None,  # if specified, a constructor taking input_res is used as input to the model
+        input_shape=(batch_size,) + input_size,  # input shape/resolution
         print_profile=detailed,  # prints the model graph with the measured profile attached to each module
         detailed=detailed,  # print the detailed profile
         warm_up=10,  # the number of warm-ups before measuring the time of each module
@@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
-            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            self,
+            model_name,
+            detail=False,
+            device='cuda',
+            torchscript=False,
+            aot_autograd=False,
+            precision='float32',
+            fuser='',
+            num_warm_iter=10,
+            num_bench_iter=50,
+            use_train_size=False,
+            **kwargs
+    ):
         self.model_name = model_name
         self.detail = detail
         self.device = device
@@ -256,7 +268,13 @@ class BenchmarkRunner:
 
 class InferenceBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.eval()
 
@@ -325,7 +343,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
 
 class TrainBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.train()
 
@@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
     return max(0, int(out_batch_size))
 
 
-def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
+def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
     batch_size = initial_batch_size
     results = dict()
     error_str = 'Unknown'
@@ -507,8 +531,11 @@
             if 'channels_last' in error_str:
                 _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.error(f'"{error_str}" while running benchmark.')
+            if no_batch_size_retry:
+                break
             batch_size = decay_batch_exp(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     return results
@@ -550,7 +577,13 @@
 
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
-        run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
+        run_results = _try_run(
+            model,
+            bench_fn,
+            bench_kwargs=bench_kwargs,
+            initial_batch_size=batch_size,
+            no_batch_size_retry=args.no_retry,
+        )
         if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
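
Note on the `profile_deepspeed` change: newer DeepSpeed releases return a `(flops, macs, params)` 3-tuple from `get_model_profile` and take `input_shape` in place of the old `input_res`/`input_constructor` pair, hence the `_, macs, _` unpacking above. A minimal sketch of the updated call, assuming a recent DeepSpeed version (the toy model and shapes are illustrative, not what benchmark.py profiles):

```python
import torch
from deepspeed.profiling.flops_profiler import get_model_profile

# Illustrative model only; benchmark.py profiles timm models instead.
model = torch.nn.Linear(3 * 224 * 224, 1000)

# Newer API: input_shape replaces input_res/input_constructor, and the
# return value is a (flops, macs, params) 3-tuple rather than a pair.
_, macs, _ = get_model_profile(
    model=model,
    input_shape=(1, 3 * 224 * 224),
    print_profile=False,
    detailed=False,
    warm_up=10,
    as_string=False,  # raw numbers instead of human-readable strings
)
print(f'MACs: {macs}')
```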
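
Note on the `_try_run` change: the old code emitted a single warning that quoted the stale, pre-decay batch size; the new code logs the error first, bails out immediately when `--no-retry` is set, and only then decays the batch size and reports the value actually used for the retry. A minimal sketch of that loop follows; `run_with_retry` is a hypothetical stand-in for `_try_run`, and the body of `decay_batch_exp` is assumed (only its signature and final return line appear in this diff):

```python
import logging

_logger = logging.getLogger(__name__)


def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    # Assumed body: halve the batch and snap anything larger than
    # `divisor` to a multiple of it, so retries stay at "nice" sizes.
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    return max(0, int(out_batch_size))


def run_with_retry(bench_fn, batch_size, no_retry=False, **bench_kwargs):
    # Hypothetical stand-in for _try_run: run the benchmark, and on a
    # RuntimeError (typically a CUDA OOM) either fail fast (--no-retry)
    # or decay the batch size and go around again.
    while batch_size:
        try:
            return bench_fn(batch_size=batch_size, **bench_kwargs)
        except RuntimeError as e:
            _logger.error(f'"{e}" while running benchmark.')
            if no_retry:
                break
            batch_size = decay_batch_exp(batch_size)
            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
    return {'error': 'benchmark failed'}
```

With the flag set, an invocation along the lines of `python benchmark.py --bench train --no-retry` fails on the first error instead of retrying at progressively smaller batch sizes (`--bench` is shown in this diff; any other flags in a real invocation are whatever benchmark.py already accepts).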