@@ -473,6 +473,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
 def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
+    error_str = 'Unknown'
     while batch_size >= 1:
         torch.cuda.empty_cache()
         try:
@@ -480,13 +481,13 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
             results = bench.run()
             return results
         except RuntimeError as e:
-            e_str = str(e)
-            print(e_str)
-            if 'channels_last' in e_str:
-                print(f'Error: {model_name} not supported in channels_last, skipping.')
+            error_str = str(e)
+            if 'channels_last' in error_str:
+                _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            print(f'Error: "{e_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
             batch_size = decay_batch_exp(batch_size)
+    results['error'] = error_str
     return results
 
 
@@ -528,13 +529,14 @@ def benchmark(args):
         model_results = OrderedDict(model=model)
         for prefix, bench_fn in zip(prefixes, bench_fns):
             run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
-            if prefix:
+            if prefix and 'error' not in run_results:
                 run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
             model_results.update(run_results)
-    param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
-    model_results.setdefault('param_count', param_count)
-    model_results.pop('train_param_count', 0)
-    return model_results if model_results['param_count'] else dict()
+    if 'error' not in model_results:
+        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
+        model_results.setdefault('param_count', param_count)
+        model_results.pop('train_param_count', 0)
+    return model_results
 
 
 def main():
@@ -578,13 +580,15 @@ def main():
             sort_key = 'train_samples_per_sec'
         elif 'profile' in args.bench:
             sort_key = 'infer_gmacs'
+
+        results = filter(lambda x: sort_key in x, results)
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
         if len(results):
             write_results(results_file, results)
     else:
         results = benchmark(args)
-        json_str = json.dumps(results, indent=4)
-        print(json_str)
+        # output results in JSON to stdout w/ delimiter for runner script
+        print(f'--result\n{json.dumps(results, indent=4)}')
 
 
 def write_results(results_file, results):