diff --git a/train.py b/train.py
index 3943c7d0..4be594e8 100755
--- a/train.py
+++ b/train.py
@@ -319,21 +319,54 @@ def main():
 
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
-    if 'WORLD_SIZE' in os.environ:
+    if 'SLURM_PROCID' in os.environ:
+        backend='nccl'
+        port = "13333"
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        if '[' in node_list:
+            beg = node_list.find('[')
+            pos1 = node_list.find('-', beg)
+            if pos1 < 0:
+                pos1 = 1000
+            pos2 = node_list.find(',', beg)
+            if pos2 < 0:
+                pos2 = 1000
+            node_list = node_list[:min(pos1, pos2)].replace('[', '')
+        addr = node_list[8:].replace('-', '.')
+        os.environ['MASTER_PORT'] = port
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        if backend == 'nccl':
+            torch.distributed.init_process_group(backend='nccl')
+        else:
+            torch.distributed.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
+        rank = torch.distributed.get_rank()
+        device = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device)
+
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-    args.device = 'cuda:0'
-    args.world_size = 1
-    args.rank = 0  # global rank
-    if args.distributed:
-        args.device = 'cuda:%d' % args.local_rank
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
-        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-                     % (args.rank, args.world_size))
-    else:
-        _logger.info('Training with a single process on 1 GPUs.')
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.rank = rank
+        args.local_rank = rank
+        args.device = device
+    elif 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+        args.device = 'cuda:0'
+        args.world_size = 1
+        args.rank = 0  # global rank
+        if args.distributed:
+            args.device = 'cuda:%d' % args.local_rank
+            torch.cuda.set_device(args.local_rank)
+            torch.distributed.init_process_group(backend='nccl', init_method='env://')
+            args.world_size = torch.distributed.get_world_size()
+            args.rank = torch.distributed.get_rank()
+            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
+                         % (args.rank, args.world_size))
+        else:
+            _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0
 
     # resolve AMP arguments based on PyTorch / Apex availability
@@ -524,7 +557,6 @@ def main():
         vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
-        num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
         mean=data_config['mean'],
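
Note on the first hunk: it derives the rendezvous settings (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) from SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST before calling torch.distributed.init_process_group, and keeps the old WORLD_SIZE/local_rank path as a fallback. The sketch below is not part of the patch; it shows the same idea in standalone form. The helper name setup_slurm_distributed, the default port, and the use of `scontrol show hostnames` (a more portable way to resolve the first node than the cluster-specific `node_list[8:].replace('-', '.')` slicing in the hunk) are illustrative assumptions.

# Hedged sketch (not part of the patch): Slurm-driven torch.distributed setup.
# The helper name and default port are illustrative, not taken from train.py.
import os
import subprocess

import torch


def setup_slurm_distributed(backend='nccl', port='13333'):
    proc_id = int(os.environ['SLURM_PROCID'])    # global rank assigned by Slurm
    ntasks = int(os.environ['SLURM_NTASKS'])     # total number of processes
    node_list = os.environ['SLURM_NODELIST']     # e.g. "node[001-004]"

    # Expand the compressed node list and take the first host as the master.
    master_addr = subprocess.check_output(
        ['scontrol', 'show', 'hostnames', node_list]).decode().split()[0]

    os.environ.setdefault('MASTER_ADDR', master_addr)
    os.environ.setdefault('MASTER_PORT', port)
    os.environ['RANK'] = str(proc_id)
    os.environ['WORLD_SIZE'] = str(ntasks)

    # The default init_method is 'env://', so the variables above suffice.
    torch.distributed.init_process_group(backend=backend)

    # One process per GPU: map the global rank onto the node-local devices.
    device = proc_id % torch.cuda.device_count()
    torch.cuda.set_device(device)
    return proc_id, ntasks, device

Under a launch such as `srun --ntasks-per-node=<gpus_per_node> python train.py ...`, each task would call a helper like this once before building the model, which is the flow the SLURM branch in the hunk implements inline.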