Merge pull request #882 from ShoufaChen/master

fix `use_amp`
pull/898/head
Ross Wightman 3 years ago committed by GitHub
commit 3f9959cdd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -397,7 +397,7 @@ def main():
     # setup synchronized BatchNorm for distributed training
     if args.distributed and args.sync_bn:
         assert not args.split_bn
-        if has_apex and use_amp != 'native':
+        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
@@ -451,7 +451,7 @@ def main():
    # setup distributed training
    if args.distributed:
-        if has_apex and use_amp != 'native':
+        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")

Loading…
Cancel
Save