Merge pull request #1336 from xwang233/add-local-rank

Make train.py compatible with torchrun
pull/1340/head
Ross Wightman 2 years ago committed by GitHub
commit 2456223052
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -355,6 +355,8 @@ def main():
args.world_size = 1 args.world_size = 1
args.rank = 0 # global rank args.rank = 0 # global rank
if args.distributed: if args.distributed:
if 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.getenv('LOCAL_RANK'))
args.device = 'cuda:%d' % args.local_rank args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend='nccl', init_method='env://')

Loading…
Cancel
Save