Give broadcast_buffers its own disable flag for now (needs more testing on interaction with dist_bn)

pull/880/head
Ross Wightman 3 years ago
parent b1c2e3eb92
commit d9abfa48df

@@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
                     help='Use NVIDIA Apex AMP mixed precision')
 parser.add_argument('--native-amp', action='store_true', default=False,
                     help='Use Native Torch AMP mixed precision')
+parser.add_argument('--no-ddp-bb', action='store_true', default=False,
+                    help='Force broadcast buffers for native DDP to off.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--pin-mem', action='store_true', default=False,
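The added option is a plain store_true flag, so broadcast buffers stay enabled unless the user explicitly opts out. The snippet below is not part of the commit; only the argument definition is taken from the hunk above, the rest is an illustrative sketch of how the flag parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')

# Dashes map to underscores, so the flag lands in args.no_ddp_bb.
print(parser.parse_args([]).no_ddp_bb)               # False -> broadcast_buffers stays True
print(parser.parse_args(['--no-ddp-bb']).no_ddp_bb)  # True  -> broadcast_buffers turned off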
@@ -463,7 +465,7 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
+            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
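For context, native DDP's broadcast_buffers=True re-broadcasts module buffers (e.g. BatchNorm running stats) from rank 0 on every forward pass; the commit message flags a possible overlap with the script's --dist-bn handling of BN stats, hence the separate opt-out. Below is a minimal, self-contained sketch of the resulting pattern, not the commit's code: the single-process gloo group, toy model, and master address/port are illustrative assumptions.

import os
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as NativeDDP

parser = argparse.ArgumentParser()
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
args = parser.parse_args()

# Illustrative single-process group; a real run starts one process per GPU
# and the launcher sets rank, world size and master address.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

# Toy model with a buffer-carrying layer (BatchNorm running stats).
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

# broadcast_buffers defaults to True, which re-syncs buffers from rank 0 on
# every forward pass; --no-ddp-bb turns that off, leaving BN stats to any
# other mechanism (e.g. --dist-bn) the training script provides.
model = NativeDDP(model, broadcast_buffers=not args.no_ddp_bb)

out = model(torch.randn(2, 3, 32, 32))
print(out.shape, 'broadcast_buffers =', not args.no_ddp_bb)
dist.destroy_process_group()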
