From 07a3522cf075480a6d8b812c40ebc8feafa236a2 Mon Sep 17 00:00:00 2001
From: jason-furiosa
Date: Thu, 29 Dec 2022 04:08:04 +0000
Subject: [PATCH] Support validation on CPU when CUDA is unavailable

---
 validate.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/validate.py b/validate.py
index fd55d408..d119d03a 100755
--- a/validate.py
+++ b/validate.py
@@ -131,6 +131,8 @@ def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
+    if not torch.cuda.is_available():
+        args.prefetcher = False
     amp_autocast = suppress  # do nothing
     if args.amp:
         if has_native_amp:
@@ -185,8 +187,10 @@ def validate(args):
     if args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
-
-    model = model.cuda()
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
     if args.apex_amp:
         model = amp.initialize(model, opt_level='O1')

@@ -236,7 +240,10 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
+        if torch.cuda.is_available():
+            input = input.cuda()
+
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
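
A minimal standalone sketch of the device-guard pattern this diff applies
throughout validate.py, runnable on a CPU-only host; the model name
'resnet18', the batch size of 2, and the 224x224 input shape are illustrative
assumptions rather than values taken from the patch:

    import torch
    import timm

    # Decide once whether CUDA is usable; each unconditional .cuda() call in
    # validate.py is guarded with this same check so CPU-only hosts still work.
    use_cuda = torch.cuda.is_available()

    model = timm.create_model('resnet18', pretrained=False)
    if use_cuda:
        model = model.cuda()
    model.eval()

    # Mirrors the guarded warmup tensor in the third hunk of the patch.
    input = torch.randn(2, 3, 224, 224)
    if use_cuda:
        input = input.cuda()

    with torch.no_grad():
        output = model(input)
    print(output.shape)  # torch.Size([2, 1000]) for resnet18

The first hunk also disables the prefetcher when CUDA is absent: timm's
prefetching loader moves batches to the GPU asynchronously via CUDA streams,
so it is only usable when a CUDA device exists.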