Support validation on CPU when CUDA is unavailable

pull/1606/head
jason-furiosa 3 years ago
parent 7cd4204a28
commit 07a3522cf0

@ -131,6 +131,8 @@ def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if not torch.cuda.is_available():
args.prefetcher = False
amp_autocast = suppress # do nothing
if args.amp:
if has_native_amp:
@ -185,8 +187,10 @@ def validate(args):
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda()
if torch.cuda.is_available():
model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
@ -236,7 +240,10 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
if torch.cuda.is_available():
input = input.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():

Loading…
Cancel
Save