From d3ee3de96a21a96fff0e2f3c3a93b6c3b12306bc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 11 Jun 2020 13:34:21 -0700 Subject: [PATCH] Update validation script: prime the first batch and clear the CUDA cache between multi-model runs --- validate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/validate.py b/validate.py index 50010cce..ebd4d849 100755 --- a/validate.py +++ b/validate.py @@ -145,7 +145,8 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - model(torch.randn((args.batch_size,) + data_config['input_size']).cuda()) + input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + model(input) end = time.time() for i, (input, target) in enumerate(loader): if args.no_prefetcher: @@ -238,6 +239,7 @@ def main(): raise e batch_size = max(batch_size // 2, args.num_gpu) print("Validation failed, reducing batch size by 50%") + torch.cuda.empty_cache() result.update(r) if args.checkpoint: result['checkpoint'] = args.checkpoint