diff --git a/train.py b/train.py index 25a03af2..08dadb02 100755 --- a/train.py +++ b/train.py @@ -315,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') group.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -group.add_argument('--apex-amp', action='store_true', default=False, - help='Use NVIDIA Apex AMP mixed precision') -group.add_argument('--native-amp', action='store_true', default=False, - help='Use Native Torch AMP mixed precision') +group.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16)') +group.add_argument('--amp-impl', default='native', type=str, + help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') group.add_argument('--pin-mem', action='store_true', default=False, @@ -385,19 +385,18 @@ def main(): # resolve AMP arguments based on PyTorch / Apex availability use_amp = None + amp_dtype = torch.float16 if args.amp: - # `--amp` chooses native amp before apex (APEX ver not actively maintained) - if has_native_amp: - args.native_amp = True - elif has_apex: - args.apex_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") + if args.amp_impl == 'apex': + assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' + use_amp = 'apex' + assert args.amp_dtype == 'float16' + else: + assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' + use_amp = 'native' + assert args.amp_dtype in ('float16', 'bfloat16') + if args.amp_dtype == 'bfloat16': + amp_dtype = torch.bfloat16 utils.random_seed(args.seed, args.rank) @@ -484,7 +483,7 @@ def main(): batch_ratio = global_batch_size / args.lr_base_size if not args.lr_base_scale: on = args.opt.lower() - args.base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' + args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' if args.lr_base_scale == 'sqrt': batch_ratio = batch_ratio ** 0.5 args.lr = args.lr_base * batch_ratio @@ -505,7 +504,7 @@ def main(): if utils.is_primary(args): _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': - amp_autocast = partial(torch.autocast, device_type=device.type) + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) if device.type == 'cuda': loss_scaler = NativeScaler() if utils.is_primary(args):