@@ -131,6 +131,8 @@ def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
+    if not torch.cuda.is_available():
+        args.prefetcher = False
     amp_autocast = suppress  # do nothing
     if args.amp:
         if has_native_amp:
@@ -185,8 +187,10 @@ def validate(args):
     if args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
 
-    model = model.cuda()
+    if torch.cuda.is_available():
+        model = model.cuda()
+
     if args.apex_amp:
         model = amp.initialize(model, opt_level='O1')
 
@@ -236,7 +240,10 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
+        if torch.cuda.is_available():
+            input = input.cuda()
+
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
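
Each hunk applies the same CUDA-availability guard so the script can run on CPU-only machines. A minimal, self-contained sketch of that fallback pattern (illustrative only; the tensor name below is not taken from the patched script):

    import torch

    x = torch.randn(4, 8)      # stand-in for a warmup input batch
    if torch.cuda.is_available():
        x = x.cuda()           # move to the GPU only when one is present
    # on a CPU-only machine, x simply stays on the CPU, mirroring the patched behaviour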