|
|
|
@ -103,11 +103,11 @@ parser.add_argument('--channels-last', action='store_true', default=False,
|
|
|
|
|
parser.add_argument('--device', default='cuda', type=str,
|
|
|
|
|
help="Device (accelerator) to use.")
|
|
|
|
|
parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
|
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
|
|
|
|
|
parser.add_argument('--apex-amp', action='store_true', default=False,
|
|
|
|
|
help='Use NVIDIA Apex AMP mixed precision')
|
|
|
|
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
|
|
|
|
help='Use Native Torch AMP mixed precision')
|
|
|
|
|
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
|
|
|
|
|
parser.add_argument('--amp-dtype', default='float16', type=str,
|
|
|
|
|
help='lower precision AMP dtype (default: float16)')
|
|
|
|
|
parser.add_argument('--amp-impl', default='native', type=str,
|
|
|
|
|
help='AMP impl to use, "native" or "apex" (default: native)')
|
|
|
|
|
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
|
|
|
|
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
|
|
|
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|
|
|
@ -142,21 +142,22 @@ def validate(args):
|
|
|
|
|
|
|
|
|
|
device = torch.device(args.device)
|
|
|
|
|
|
|
|
|
|
amp_autocast = suppress # do nothing
|
|
|
|
|
# resolve AMP arguments based on PyTorch / Apex availability
|
|
|
|
|
use_amp = None
|
|
|
|
|
amp_autocast = suppress
|
|
|
|
|
if args.amp:
|
|
|
|
|
if has_native_amp:
|
|
|
|
|
args.native_amp = True
|
|
|
|
|
elif has_apex:
|
|
|
|
|
args.apex_amp = True
|
|
|
|
|
if args.amp_impl == 'apex':
|
|
|
|
|
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
|
|
|
|
|
assert args.amp_dtype == 'float16'
|
|
|
|
|
use_amp = 'apex'
|
|
|
|
|
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning("Neither APEX or Native Torch AMP is available.")
|
|
|
|
|
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
|
|
|
|
|
if args.native_amp:
|
|
|
|
|
amp_autocast = partial(torch.autocast, device_type=device.type)
|
|
|
|
|
_logger.info('Validating in mixed precision with native PyTorch AMP.')
|
|
|
|
|
elif args.apex_amp:
|
|
|
|
|
assert device.type == 'cuda'
|
|
|
|
|
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
|
|
|
|
|
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
|
|
|
|
|
assert args.amp_dtype in ('float16', 'bfloat16')
|
|
|
|
|
use_amp = 'native'
|
|
|
|
|
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
|
|
|
|
|
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
|
|
|
|
_logger.info('Validating in mixed precision with native PyTorch AMP.')
|
|
|
|
|
else:
|
|
|
|
|
_logger.info('Validating in float32. AMP not enabled.')
|
|
|
|
|
|
|
|
|
@ -204,7 +205,7 @@ def validate(args):
|
|
|
|
|
model = memory_efficient_fusion(model)
|
|
|
|
|
|
|
|
|
|
model = model.to(device)
|
|
|
|
|
if args.apex_amp:
|
|
|
|
|
if use_amp == 'apex':
|
|
|
|
|
model = amp.initialize(model, opt_level='O1')
|
|
|
|
|
|
|
|
|
|
if args.channels_last:
|
|
|
|
|