|
|
|
@ -315,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False,
|
|
|
|
|
help='save images of input bathes every log interval for debugging')
|
|
|
|
|
group.add_argument('--amp', action='store_true', default=False,
|
|
|
|
|
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
|
|
|
|
|
group.add_argument('--apex-amp', action='store_true', default=False,
|
|
|
|
|
help='Use NVIDIA Apex AMP mixed precision')
|
|
|
|
|
group.add_argument('--native-amp', action='store_true', default=False,
|
|
|
|
|
help='Use Native Torch AMP mixed precision')
|
|
|
|
|
group.add_argument('--amp-dtype', default='float16', type=str,
|
|
|
|
|
help='lower precision AMP dtype (default: float16)')
|
|
|
|
|
group.add_argument('--amp-impl', default='native', type=str,
|
|
|
|
|
help='AMP impl to use, "native" or "apex" (default: native)')
|
|
|
|
|
group.add_argument('--no-ddp-bb', action='store_true', default=False,
|
|
|
|
|
help='Force broadcast buffers for native DDP to off.')
|
|
|
|
|
group.add_argument('--pin-mem', action='store_true', default=False,
|
|
|
|
@ -385,19 +385,18 @@ def main():
|
|
|
|
|
|
|
|
|
|
# resolve AMP arguments based on PyTorch / Apex availability
|
|
|
|
|
use_amp = None
|
|
|
|
|
amp_dtype = torch.float16
|
|
|
|
|
if args.amp:
|
|
|
|
|
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
|
|
|
|
|
if has_native_amp:
|
|
|
|
|
args.native_amp = True
|
|
|
|
|
elif has_apex:
|
|
|
|
|
args.apex_amp = True
|
|
|
|
|
if args.apex_amp and has_apex:
|
|
|
|
|
if args.amp_impl == 'apex':
|
|
|
|
|
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
|
|
|
|
|
use_amp = 'apex'
|
|
|
|
|
elif args.native_amp and has_native_amp:
|
|
|
|
|
assert args.amp_dtype == 'float16'
|
|
|
|
|
else:
|
|
|
|
|
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
|
|
|
|
|
use_amp = 'native'
|
|
|
|
|
elif args.apex_amp or args.native_amp:
|
|
|
|
|
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
|
|
|
|
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
|
|
|
|
assert args.amp_dtype in ('float16', 'bfloat16')
|
|
|
|
|
if args.amp_dtype == 'bfloat16':
|
|
|
|
|
amp_dtype = torch.bfloat16
|
|
|
|
|
|
|
|
|
|
utils.random_seed(args.seed, args.rank)
|
|
|
|
|
|
|
|
|
@ -484,7 +483,7 @@ def main():
|
|
|
|
|
batch_ratio = global_batch_size / args.lr_base_size
|
|
|
|
|
if not args.lr_base_scale:
|
|
|
|
|
on = args.opt.lower()
|
|
|
|
|
args.base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
|
|
|
|
|
args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
|
|
|
|
|
if args.lr_base_scale == 'sqrt':
|
|
|
|
|
batch_ratio = batch_ratio ** 0.5
|
|
|
|
|
args.lr = args.lr_base * batch_ratio
|
|
|
|
@ -505,7 +504,7 @@ def main():
|
|
|
|
|
if utils.is_primary(args):
|
|
|
|
|
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
|
|
|
|
|
elif use_amp == 'native':
|
|
|
|
|
amp_autocast = partial(torch.autocast, device_type=device.type)
|
|
|
|
|
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
|
|
|
|
if device.type == 'cuda':
|
|
|
|
|
loss_scaler = NativeScaler()
|
|
|
|
|
if utils.is_primary(args):
|
|
|
|
|