From 11060f84c51d813d900ddf3a7b178b3e4fe87fb3 Mon Sep 17 00:00:00 2001
From: Xiao Wang <24860335+xwang233@users.noreply.github.com>
Date: Thu, 7 Jul 2022 14:44:55 -0700
Subject: [PATCH] make train.py compatible with torchrun

---
 train.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/train.py b/train.py
index 285981fd..e5d40566 100755
--- a/train.py
+++ b/train.py
@@ -355,6 +355,8 @@ def main():
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
+        if 'LOCAL_RANK' in os.environ:
+            args.local_rank = int(os.getenv('LOCAL_RANK'))
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')