diff --git a/train.py b/train.py
index 285981fd..e5d40566 100755
--- a/train.py
+++ b/train.py
@@ -355,6 +355,8 @@ def main():
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
+        if 'LOCAL_RANK' in os.environ:
+            args.local_rank = int(os.getenv('LOCAL_RANK'))
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')