diff --git a/benchmark.py b/benchmark.py
index 4812d85c..5201f02d 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -374,14 +374,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
     while batch_size >= 1:
+        torch.cuda.empty_cache()
         try:
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
-            torch.cuda.empty_cache()
-            batch_size = decay_batch_exp(batch_size)
             print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
+            batch_size = decay_batch_exp(batch_size)
     return results
 
 
diff --git a/train.py b/train.py
index 85829fc1..4264a164 100755
--- a/train.py
+++ b/train.py
@@ -560,7 +560,7 @@ def main():
     best_metric = None
     best_epoch = None
     saver = None
-    output_dir = ''
+    output_dir = None
     if args.local_rank == 0:
         if args.experiment:
             exp_name = args.experiment
@@ -606,9 +606,10 @@ def main():
             # step LR for next epoch
             lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
-        update_summary(
-            epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-            write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
+        if output_dir is not None:
+            update_summary(
+                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
+                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
         if saver is not None:
             # save proper checkpoint with eval metric
@@ -623,7 +624,7 @@ def main():
 
 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
+        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):
 
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
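For context, the `benchmark.py` hunk keeps the OOM-retry loop but clears the CUDA allocator cache before every attempt rather than only after a failure, and decays the batch size after logging the error. Below is a minimal standalone sketch of that retry pattern; the function name `try_run_with_retry` and the halving `decay_batch_exp` helper are illustrative assumptions, not the repository's actual implementations, which live elsewhere in `benchmark.py` and may differ.

```python
import torch


def decay_batch_exp(batch_size, factor=2):
    # Assumed helper: shrink the batch size geometrically (e.g. 256 -> 128 -> 64 ...).
    return batch_size // factor


def try_run_with_retry(bench_fn, initial_batch_size, **bench_kwargs):
    """Run bench_fn, retrying with a smaller batch size when a RuntimeError (e.g. CUDA OOM) occurs."""
    batch_size = initial_batch_size
    results = dict()
    while batch_size >= 1:
        # Release cached allocator blocks before every attempt, including the first,
        # so each retry starts from as clean a memory state as possible.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            results = bench_fn(batch_size=batch_size, **bench_kwargs)
            return results
        except RuntimeError as e:
            # Report the batch size that failed, then decay it for the next attempt.
            print(f'Error: {e} while running benchmark at batch size {batch_size}.')
            batch_size = decay_batch_exp(batch_size)
    return results
```

The `train.py` hunks follow the same intent: since only the `local_rank == 0` process creates an output directory, switching the default from `''` to `None` and guarding the `update_summary` call presumably prevents other ranks (or runs without an output directory) from writing a stray `summary.csv` into the current working directory.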