Add --no-retry flag to benchmark.py to skip batch_size decay and retry on error. Fix #1226. Update deepspeed profile usage for latest DS releases. Fix # 1333

pull/1327/head
Ross Wightman 2 years ago
parent db0cee9910
commit 28e0152043

@ -71,6 +71,8 @@ parser.add_argument('--bench', default='both', type=str,
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False,
help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int,
@ -169,10 +171,9 @@ def resolve_precision(precision: str):
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
macs, _ = get_model_profile(
_, macs, _ = get_model_profile(
model=model,
input_res=(batch_size,) + input_size, # input shape or input to the input_constructor
input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
input_shape=(batch_size,) + input_size, # input shape/resolution
print_profile=detailed, # prints the model graph with the measured profile attached to each module
detailed=detailed, # print the detailed profile
warm_up=10, # the number of warm-ups before measuring the time of each module
@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self,
model_name,
detail=False,
device='cuda',
torchscript=False,
aot_autograd=False,
precision='float32',
fuser='',
num_warm_iter=10,
num_bench_iter=50,
use_train_size=False,
**kwargs
):
self.model_name = model_name
self.detail = detail
self.device = device
@ -256,7 +268,13 @@ class BenchmarkRunner:
class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.eval()
@ -325,7 +343,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
class TrainBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.train()
@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
return max(0, int(out_batch_size))
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
batch_size = initial_batch_size
results = dict()
error_str = 'Unknown'
@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
if 'channels_last' in error_str:
_logger.error(f'{model_name} not supported in channels_last, skipping.')
break
_logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
_logger.error(f'"{error_str}" while running benchmark.')
if no_batch_size_retry:
break
batch_size = decay_batch_exp(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str
return results
@ -550,7 +577,13 @@ def benchmark(args):
model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns):
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
run_results = _try_run(
model,
bench_fn,
bench_kwargs=bench_kwargs,
initial_batch_size=batch_size,
no_batch_size_retry=args.no_retry,
)
if prefix and 'error' not in run_results:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results)

Loading…
Cancel
Save