|
|
|
@@ -116,15 +116,20 @@ def validate(args):
|
|
|
|
|
args.prefetcher = not args.no_prefetcher
|
|
|
|
|
amp_autocast = suppress # do nothing
|
|
|
|
|
if args.amp:
|
|
|
|
|
if has_apex:
|
|
|
|
|
args.apex_amp = True
|
|
|
|
|
elif has_native_amp:
|
|
|
|
|
if has_native_amp:
|
|
|
|
|
args.native_amp = True
|
|
|
|
|
elif has_apex:
|
|
|
|
|
args.apex_amp = True
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
|
|
|
|
|
_logger.warning("Neither APEX or Native Torch AMP is available.")
|
|
|
|
|
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
|
|
|
|
|
if args.native_amp:
|
|
|
|
|
amp_autocast = torch.cuda.amp.autocast
|
|
|
|
|
_logger.info('Validating in mixed precision with native PyTorch AMP.')
|
|
|
|
|
elif args.apex_amp:
|
|
|
|
|
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
|
|
|
|
|
else:
|
|
|
|
|
_logger.info('Validating in float32. AMP not enabled.')
|
|
|
|
|
|
|
|
|
|
if args.legacy_jit:
|
|
|
|
|
set_jit_legacy()
|
|
|
|
|