From eb333d6641e2d68d63bddf731adb58909fc1b89c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 14 Oct 2022 15:51:20 -0700
Subject: [PATCH] Update validate.py to use updated amp args for impl/dtype

---
 validate.py | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/validate.py b/validate.py
index cdce82bf..1a1ea9cd 100755
--- a/validate.py
+++ b/validate.py
@@ -103,11 +103,11 @@ parser.add_argument('--channels-last', action='store_true', default=False,
 parser.add_argument('--device', default='cuda', type=str,
                     help="Device (accelerator) to use.")
 parser.add_argument('--amp', action='store_true', default=False,
-                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
-parser.add_argument('--apex-amp', action='store_true', default=False,
-                    help='Use NVIDIA Apex AMP mixed precision')
-parser.add_argument('--native-amp', action='store_true', default=False,
-                    help='Use Native Torch AMP mixed precision')
+                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16)')
+parser.add_argument('--amp-impl', default='native', type=str,
+                    help='AMP impl to use, "native" or "apex" (default: native)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -142,21 +142,22 @@ def validate(args):

     device = torch.device(args.device)

-    amp_autocast = suppress  # do nothing
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    amp_autocast = suppress
     if args.amp:
-        if has_native_amp:
-            args.native_amp = True
-        elif has_apex:
-            args.apex_amp = True
+        if args.amp_impl == 'apex':
+            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
+            assert args.amp_dtype == 'float16'
+            use_amp = 'apex'
+            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            _logger.warning("Neither APEX or Native Torch AMP is available.")
-    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
-    if args.native_amp:
-        amp_autocast = partial(torch.autocast, device_type=device.type)
-        _logger.info('Validating in mixed precision with native PyTorch AMP.')
-    elif args.apex_amp:
-        assert device.type == 'cuda'
-        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
+            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
+            assert args.amp_dtype in ('float16', 'bfloat16')
+            use_amp = 'native'
+            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+            _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
         _logger.info('Validating in float32. AMP not enabled.')

@@ -204,7 +205,7 @@ def validate(args):
         model = memory_efficient_fusion(model)

     model = model.to(device)
-    if args.apex_amp:
+    if use_amp == 'apex':
         model = amp.initialize(model, opt_level='O1')

     if args.channels_last:
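
Note: a minimal usage sketch of the updated flags. The dataset path and
model name below are illustrative placeholders, not taken from the patch:

    # native PyTorch AMP (the default impl), autocast in bfloat16
    python validate.py /path/to/imagenet --model resnet50 --amp --amp-dtype bfloat16

    # NVIDIA Apex AMP instead of native; float16 is the only dtype accepted
    python validate.py /path/to/imagenet --model resnet50 --amp --amp-impl apex

With the native impl, amp_autocast becomes a partial of torch.autocast bound
to the resolved device type and dtype, so the forward pass elsewhere in
validate.py can run under "with amp_autocast():" unchanged; whether bfloat16
autocast actually works still depends on the device and PyTorch build.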