|
|
@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
|
|
|
|
help='Use NVIDIA Apex AMP mixed precision')
|
|
|
|
help='Use NVIDIA Apex AMP mixed precision')
|
|
|
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
|
|
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
|
|
|
help='Use Native Torch AMP mixed precision')
|
|
|
|
help='Use Native Torch AMP mixed precision')
|
|
|
|
|
|
|
|
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Force broadcast buffers for native DDP to off.')
|
|
|
|
parser.add_argument('--channels-last', action='store_true', default=False,
|
|
|
|
parser.add_argument('--channels-last', action='store_true', default=False,
|
|
|
|
help='Use channels_last memory layout')
|
|
|
|
help='Use channels_last memory layout')
|
|
|
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
|
|
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
|
|
@ -463,7 +465,7 @@ def main():
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
_logger.info("Using native Torch DistributedDataParallel.")
|
|
|
|
_logger.info("Using native Torch DistributedDataParallel.")
|
|
|
|
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
|
|
|
|
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
|
|
|
|
# NOTE: EMA model does not need to be wrapped by DDP
|
|
|
|
# NOTE: EMA model does not need to be wrapped by DDP
|
|
|
|
|
|
|
|
|
|
|
|
# setup learning rate schedule and starting epoch
|
|
|
|
# setup learning rate schedule and starting epoch
|
|
|
|