@@ -319,21 +319,54 @@ def main():
 
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
-    if 'WORLD_SIZE' in os.environ:
+    if 'SLURM_PROCID' in os.environ:
+        backend='nccl'
+        port = "13333"
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        if '[' in node_list:
+            beg = node_list.find('[')
+            pos1 = node_list.find('-', beg)
+            if pos1 < 0:
+                pos1 = 1000
+            pos2 = node_list.find(',', beg)
+            if pos2 < 0:
+                pos2 = 1000
+            node_list = node_list[:min(pos1, pos2)].replace('[', '')
+        addr = node_list[8:].replace('-', '.')
+        os.environ['MASTER_PORT'] = port
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        if backend == 'nccl':
+            torch.distributed.init_process_group(backend='nccl')
+        else:
+            torch.distributed.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
+        rank = torch.distributed.get_rank()
+        device = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device)
+
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-    args.device = 'cuda:0'
-    args.world_size = 1
-    args.rank = 0  # global rank
-    if args.distributed:
-        args.device = 'cuda:%d' % args.local_rank
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
-        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-                     % (args.rank, args.world_size))
-    else:
-        _logger.info('Training with a single process on 1 GPUs.')
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.rank = rank
+        args.local_rank = rank
+        args.device = device
+    elif 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+        args.device = 'cuda:0'
+        args.world_size = 1
+        args.rank = 0  # global rank
+        if args.distributed:
+            args.device = 'cuda:%d' % args.local_rank
+            torch.cuda.set_device(args.local_rank)
+            torch.distributed.init_process_group(backend='nccl', init_method='env://')
+            args.world_size = torch.distributed.get_world_size()
+            args.rank = torch.distributed.get_rank()
+            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
+                         % (args.rank, args.world_size))
+        else:
+            _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0
 
     # resolve AMP arguments based on PyTorch / Apex availability
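
For reference, the SLURM branch added above builds the rendezvous address by string-parsing SLURM_NODELIST rather than querying SLURM for it: it keeps only the first hostname of a compressed node list, drops an 8-character site prefix, and rewrites the remaining dashes as dots to obtain an IP for MASTER_ADDR, with MASTER_PORT hard-coded to "13333". The standalone sketch below mirrors that parsing; the node names in it are hypothetical and assume a naming scheme like SH-IDC1-10-198-4-100, where everything after the 8-character prefix is a dash-separated IPv4 address. On clusters with different hostnames the [8:] slice would need adjusting (or the first node could be resolved with `scontrol show hostnames` instead).

def master_addr_from_nodelist(node_list: str) -> str:
    """Return the IP encoded in the first node of a SLURM_NODELIST string."""
    if '[' in node_list:
        # Compressed form such as 'SH-IDC1-10-198-4-[100-103,110]': keep the
        # text up to the first '-' or ',' inside the brackets, which leaves
        # only the first expanded hostname.
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)
        if pos1 < 0:
            pos1 = 1000
        pos2 = node_list.find(',', beg)
        if pos2 < 0:
            pos2 = 1000
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    # Drop the 8-character site prefix and turn the remaining dashes into dots.
    return node_list[8:].replace('-', '.')


if __name__ == '__main__':
    # Hypothetical node names, for illustration only.
    print(master_addr_from_nodelist('SH-IDC1-10-198-4-100'))        # 10.198.4.100
    print(master_addr_from_nodelist('SH-IDC1-10-198-4-[100-103]'))  # 10.198.4.100

With this patch, jobs launched through srun (which sets SLURM_PROCID, SLURM_NTASKS, and SLURM_NODELIST for every task) presumably take the new branch, while launchers that only export WORLD_SIZE/RANK, such as torch.distributed.launch, still fall through to the original elif path.
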
@@ -524,7 +557,6 @@ def main():
         vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
-        num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
         mean=data_config['mean'],