@@ -397,7 +397,7 @@ def main():
     # setup synchronized BatchNorm for distributed training
     if args.distributed and args.sync_bn:
         assert not args.split_bn
-        if has_apex and use_amp != 'native':
+        if has_apex and use_amp == 'apex':
             # Apex SyncBN preferred unless native amp is activated
             model = convert_syncbn_model(model)
         else:
@@ -451,7 +451,7 @@ def main():
 
     # setup distributed training
     if args.distributed:
-        if has_apex and use_amp != 'native':
+        if has_apex and use_amp == 'apex':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:
                 _logger.info("Using NVIDIA APEX DistributedDataParallel.")
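Both hunks tighten the same condition: the old check (`has_apex and use_amp != 'native'`) took the Apex SyncBN/DDP path whenever Apex was installed and native AMP was not explicitly selected, including runs with AMP disabled; the new check (`has_apex and use_amp == 'apex'`) only uses Apex when Apex AMP is actually chosen. Below is a minimal sketch of the resulting SyncBN selection, assuming `use_amp` resolves to 'apex', 'native', or None earlier in main(), and assuming the truncated `else:` branch falls back to PyTorch's native SyncBatchNorm (that line is outside the hunk); the helper name is hypothetical.

import torch

def maybe_convert_sync_bn(model, args, use_amp, has_apex):
    # Hypothetical helper, not part of the patch: mirrors the branch above.
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN only when Apex AMP is explicitly selected
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            # Assumed fallback: PyTorch's built-in SyncBatchNorm conversion
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model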