Set up torch.distributed for the SLURM platform and set num_aug_splits

Branch: pull/874/head
Author: YangYang, authored 4 years ago, committed by GitHub
Parent: a6e8598aaf
Commit: 83a60f55fc

@@ -319,21 +319,54 @@ def main():
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
-    if 'WORLD_SIZE' in os.environ:
+    if 'SLURM_PROCID' in os.environ:
+        backend = 'nccl'
+        port = "13333"
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        if '[' in node_list:
+            beg = node_list.find('[')
+            pos1 = node_list.find('-', beg)
+            if pos1 < 0:
+                pos1 = 1000
+            pos2 = node_list.find(',', beg)
+            if pos2 < 0:
+                pos2 = 1000
+            node_list = node_list[:min(pos1, pos2)].replace('[', '')
+        addr = node_list[8:].replace('-', '.')
+        os.environ['MASTER_PORT'] = port
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        if backend == 'nccl':
+            torch.distributed.init_process_group(backend='nccl')
+        else:
+            torch.distributed.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
+        rank = torch.distributed.get_rank()
+        device = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device)
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-    args.device = 'cuda:0'
-    args.world_size = 1
-    args.rank = 0  # global rank
-    if args.distributed:
-        args.device = 'cuda:%d' % args.local_rank
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
-        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-                     % (args.rank, args.world_size))
-    else:
-        _logger.info('Training with a single process on 1 GPUs.')
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.rank = rank
+        args.local_rank = rank
+        args.device = device
+    elif 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+        args.device = 'cuda:0'
+        args.world_size = 1
+        args.rank = 0  # global rank
+        if args.distributed:
+            args.device = 'cuda:%d' % args.local_rank
+            torch.cuda.set_device(args.local_rank)
+            torch.distributed.init_process_group(backend='nccl', init_method='env://')
+            args.world_size = torch.distributed.get_world_size()
+            args.rank = torch.distributed.get_rank()
+            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
+                         % (args.rank, args.world_size))
+        else:
+            _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0

     # resolve AMP arguments based on PyTorch / Apex availability
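For reference, the MASTER_ADDR derivation in the SLURM branch above only works when node hostnames follow a fixed pattern: an eight-character site prefix followed by the node's IP address written with dashes. The sketch below traces the same parsing logic on a hypothetical SLURM_NODELIST value; the 'SH-IDC1-' prefix and the addresses are made up for illustration and are not part of this commit.

    # Hypothetical example of the SLURM_NODELIST parsing above, run outside SLURM.
    # The cluster prefix 'SH-IDC1-' and the addresses are illustrative assumptions.
    node_list = 'SH-IDC1-10-5-30-[62,64-66]'        # compressed multi-node list
    if '[' in node_list:
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)             # first '-' after the bracket
        if pos1 < 0:
            pos1 = 1000
        pos2 = node_list.find(',', beg)             # first ',' after the bracket
        if pos2 < 0:
            pos2 = 1000
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    # node_list is now 'SH-IDC1-10-5-30-62': only the first node in the list is kept
    addr = node_list[8:].replace('-', '.')          # strip the 8-char prefix, dashes -> dots
    print(addr)                                     # prints '10.5.30.62'

This path also assumes the script is launched with one SLURM task per GPU (e.g. via srun), so that SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST are populated for each process; on clusters with a different hostname scheme the hard-coded node_list[8:] slice would need adjusting.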
@@ -524,7 +557,6 @@ def main():
         vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
-        num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
         mean=data_config['mean'],
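A side note on the rendezvous used by the first hunk: torch.distributed.init_process_group defaults to init_method='env://', so the NCCL call works because MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are exported just before it. A minimal single-process sketch of that pattern follows; the localhost address and port are placeholders, not values from this commit.

    # Minimal sketch of env:// rendezvous with one process; values are placeholders.
    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '13333')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')

    backend = 'nccl' if torch.cuda.is_available() else 'gloo'   # gloo fallback for CPU-only boxes
    dist.init_process_group(backend=backend)                    # reads the env vars set above
    print(dist.get_rank(), dist.get_world_size())               # 0 1
    dist.destroy_process_group()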
