Fix #566, summary.csv writing to pwd on local_rank != 0. Tweak benchmark mem handling to see if it reduces likelihood of 'bad' exceptions on OOM.

pull/571/head
Ross Wightman 4 years ago
parent 1b0c8e7b01
commit e15e68d881

@@ -374,14 +374,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
     while batch_size >= 1:
+        torch.cuda.empty_cache()
         try:
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
-            torch.cuda.empty_cache()
+            batch_size = decay_batch_exp(batch_size)
             print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
-            batch_size = decay_batch_exp(batch_size)
     return results
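
For context, a minimal sketch of the retry loop as it behaves after this change: the CUDA cache is emptied before every attempt rather than only inside the exception handler, and the batch size is decayed before the message is printed so the reported size is the one that will actually be retried. The decay helper below is a simple halving stand-in for benchmark.py's decay_batch_exp, whose actual schedule may differ.

    import torch

    def _decay_half(batch_size):
        # Stand-in for decay_batch_exp (the real decay schedule may differ).
        return batch_size // 2

    def _try_run_sketch(model_name, bench_fn, initial_batch_size, bench_kwargs):
        # Sketch only; mirrors the post-change loop above, not the full benchmark.py code.
        batch_size = initial_batch_size
        results = dict()
        while batch_size >= 1:
            torch.cuda.empty_cache()  # release cached blocks before each attempt, not just after an OOM
            try:
                bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
                results = bench.run()
                return results
            except RuntimeError as e:
                batch_size = _decay_half(batch_size)  # decay first so the printed size is the retry size
                print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
        return results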

@@ -560,7 +560,7 @@ def main():
     best_metric = None
     best_epoch = None
     saver = None
-    output_dir = ''
+    output_dir = None
     if args.local_rank == 0:
         if args.experiment:
             exp_name = args.experiment
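
The old empty-string default is the root of #566: output_dir is only assigned a real path inside the args.local_rank == 0 branch, so every other rank kept the default and later joined it into a bare relative path in the current working directory. A quick illustration of why '' is a risky default here:

    import os

    # Old default: non-zero ranks built a path relative to the pwd and wrote there.
    os.path.join('', 'summary.csv')   # -> 'summary.csv'

    # New default: output_dir stays None on non-zero ranks, and the summary write
    # is skipped entirely by the `if output_dir is not None:` guard added below.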
@@ -606,9 +606,10 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

-            update_summary(
-                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
+            if output_dir is not None:
+                update_summary(
+                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
+                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

             if saver is not None:
                 # save proper checkpoint with eval metric
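
For readers unfamiliar with the helper, here is a rough sketch of what a per-epoch CSV summary writer like update_summary does (the real timm implementation also supports optional wandb logging, and its details may differ). It shows why write_header is passed as best_metric is None: the header row is only written on the first epoch, before any best metric exists.

    import csv
    from collections import OrderedDict

    def update_summary_sketch(epoch, train_metrics, eval_metrics, filename, write_header=False):
        # Append one row per epoch, prefixing metric names with their phase.
        rowd = OrderedDict(epoch=epoch)
        rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
        with open(filename, mode='a') as cf:
            dw = csv.DictWriter(cf, fieldnames=rowd.keys())
            if write_header:  # only before the first row, e.g. when no best_metric exists yet
                dw.writeheader()
            dw.writerow(rowd)
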
@@ -623,7 +624,7 @@ def main():

 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
+        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
