diff --git a/train.py b/train.py index 3943c7d0..84d8b2ea 100755 --- a/train.py +++ b/train.py @@ -397,7 +397,7 @@ def main(): # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn - if has_apex and use_amp != 'native': + if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: @@ -451,7 +451,7 @@ def main(): # setup distributed training if args.distributed: - if has_apex and use_amp != 'native': + if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.")