From d9abfa48df3090e6157fefa22e9ae05c28e62d07 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 1 Oct 2021 13:43:55 -0700
Subject: [PATCH] Make broadcast_buffers disable its own flag for now (needs
 more testing on interaction with dist_bn)

---
 train.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index d95611ad..785b99e2 100755
--- a/train.py
+++ b/train.py
@@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
                     help='Use NVIDIA Apex AMP mixed precision')
 parser.add_argument('--native-amp', action='store_true', default=False,
                     help='Use Native Torch AMP mixed precision')
+parser.add_argument('--no-ddp-bb', action='store_true', default=False,
+                    help='Force broadcast buffers for native DDP to off.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--pin-mem', action='store_true', default=False,
@@ -463,7 +465,7 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
+            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
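
Note (not part of the patch): the sketch below shows how the new --no-ddp-bb flag feeds into
torch.nn.parallel.DistributedDataParallel. The argparse flag and the broadcast_buffers expression
are taken from the diff above; the process-group init, the toy BatchNorm model, and the launch
assumptions (torchrun setting LOCAL_RANK and rendezvous env vars) are illustrative stand-ins for
what train.py actually builds, not the file's real contents.

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as NativeDDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int,
                    default=int(os.environ.get('LOCAL_RANK', 0)))
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
args = parser.parse_args()  # argparse exposes the flag as args.no_ddp_bb

# Assumes a distributed launch (e.g. torchrun) that sets MASTER_ADDR/MASTER_PORT,
# RANK and WORLD_SIZE so the default process group can initialize.
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)

# Toy stand-in model; BatchNorm running stats live in buffers, which is what
# broadcast_buffers re-broadcasts from rank 0 on every forward pass.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).cuda(args.local_rank)

# DDP's broadcast_buffers defaults to True; passing --no-ddp-bb turns it off so
# buffer sync can instead be handled separately (e.g. timm's --dist-bn
# reduce/broadcast step at the end of each epoch).
model = NativeDDP(model, device_ids=[args.local_rank],
                  broadcast_buffers=not args.no_ddp_bb)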