setup torch.distributed for slurm platform and set num_aug_splits

pull/874/head
YangYang 4 years ago committed by GitHub
parent a6e8598aaf
commit 83a60f55fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -319,21 +319,54 @@ def main():
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
if 'SLURM_PROCID' in os.environ:
backend='nccl'
port = "13333"
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
if '[' in node_list:
beg = node_list.find('[')
pos1 = node_list.find('-', beg)
if pos1 < 0:
pos1 = 1000
pos2 = node_list.find(',', beg)
if pos2 < 0:
pos2 = 1000
node_list = node_list[:min(pos1, pos2)].replace('[', '')
addr = node_list[8:].replace('-', '.')
os.environ['MASTER_PORT'] = port
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
if backend == 'nccl':
torch.distributed.init_process_group(backend='nccl')
else:
torch.distributed.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
rank = torch.distributed.get_rank()
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on 1 GPUs.')
args.world_size = int(os.environ['WORLD_SIZE'])
args.rank = rank
args.local_rank = rank
args.device = device
elif 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
@ -524,7 +557,6 @@ def main():
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],

Loading…
Cancel
Save