diff --git a/validate.py b/validate.py index 50010cce..ebd4d849 100755 --- a/validate.py +++ b/validate.py @@ -145,7 +145,8 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - model(torch.randn((args.batch_size,) + data_config['input_size']).cuda()) + input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + model(input) end = time.time() for i, (input, target) in enumerate(loader): if args.no_prefetcher: @@ -238,6 +239,7 @@ def main(): raise e batch_size = max(batch_size // 2, args.num_gpu) print("Validation failed, reducing batch size by 50%") + torch.cuda.empty_cache() result.update(r) if args.checkpoint: result['checkpoint'] = args.checkpoint