You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
7.2 KiB
190 lines
7.2 KiB
4 years ago
|
""" PyTorch distributed helpers
|
||
|
|
||
|
Some of this lifted from Detectron2 with other fns added by myself.
|
||
|
|
||
|
FIXME many functions remain unfinished/untested
|
||
|
"""
|
||
|
from typing import Dict, Tuple, List, Union, Any, Callable
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch.distributed import ReduceOp
|
||
|
|
||
|
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
|
||
|
|
||
|
|
||
|
def synchronize_torch():
|
||
|
"""
|
||
|
Helper function to synchronize (barrier) among all processes when
|
||
|
using distributed training
|
||
|
"""
|
||
|
if not dist.is_available():
|
||
|
return
|
||
|
if not dist.is_initialized():
|
||
|
return
|
||
|
world_size = dist.get_world_size()
|
||
|
if world_size == 1:
|
||
|
return
|
||
|
dist.barrier()
|
||
|
|
||
|
|
||
|
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
|
||
|
"""
|
||
|
All reduce the tensors in a sequence (dict, list, tuple)
|
||
|
|
||
|
Args:
|
||
|
values (dict): inputs to be reduced. All the values must be scalar Tensor.
|
||
|
average (bool): whether to do average or sum
|
||
|
Returns:
|
||
|
a sequence with the same type as input (dict, list, tuple)
|
||
|
"""
|
||
|
world_size = dist.get_world_size(group)
|
||
|
if world_size <= 1:
|
||
|
return values
|
||
|
|
||
|
with torch.no_grad():
|
||
|
names = None
|
||
|
if isinstance(values, dict):
|
||
|
names = values.keys()
|
||
|
reduce_values = torch.stack(tuple(values.values()), dim=0)
|
||
|
elif isinstance(values, (tuple, list)):
|
||
|
reduce_values = torch.stack(values, dim=0)
|
||
|
else:
|
||
|
reduce_values = values
|
||
|
dist.all_reduce(reduce_values, op=op, group=group)
|
||
|
if average:
|
||
|
reduce_values /= world_size
|
||
|
if isinstance(values, dict):
|
||
|
reduce_values = {k: v for k, v in zip(names, reduce_values)}
|
||
|
elif isinstance(values, (tuple, list)):
|
||
|
reduce_values = type(values)(v for v in reduce_values)
|
||
|
return reduce_values
|
||
|
|
||
|
|
||
|
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
|
||
|
"""
|
||
|
All reduce the tensors in a sequence (dict, list, tuple)
|
||
|
|
||
|
Args:
|
||
|
values (dict): inputs to be reduced. All the values must be scalar Tensor.
|
||
|
average (bool): whether to do average or sum
|
||
|
Returns:
|
||
|
a sequence with the same type as input (dict, list, tuple)
|
||
|
"""
|
||
|
world_size = dist.get_world_size(group)
|
||
|
this_rank = dist.get_rank()
|
||
|
if world_size <= 1:
|
||
|
return values
|
||
|
|
||
|
with torch.no_grad():
|
||
|
names = None
|
||
|
if isinstance(values, dict):
|
||
|
names = values.keys()
|
||
|
reduce_values = torch.stack(tuple(values.values()), dim=0)
|
||
|
elif isinstance(values, (tuple, list)):
|
||
|
reduce_values = torch.stack(values, dim=0)
|
||
|
else:
|
||
|
reduce_values = values
|
||
|
reduce_values = torch.stack(reduce_values, dim=0)
|
||
|
dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
|
||
|
if average and this_rank == dst_rank:
|
||
|
reduce_values /= world_size
|
||
|
if isinstance(values, dict):
|
||
|
reduce_values = {k: v for k, v in zip(names, reduce_values)}
|
||
|
elif isinstance(values, (tuple, list)):
|
||
|
reduce_values = type(values)(v for v in reduce_values)
|
||
|
return reduce_values
|
||
|
|
||
|
|
||
|
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
|
||
|
world_size = dist.get_world_size(group)
|
||
|
|
||
|
def _do_gather(tensor):
|
||
|
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||
|
dist.all_gather(tensor_list, tensor, group=group)
|
||
|
return join_fn(tensor_list, dim=join_dim)
|
||
|
|
||
|
if isinstance(values, dict):
|
||
|
gathered = {k: _do_gather(v) for k, v in values.items()}
|
||
|
return gathered
|
||
|
elif isinstance(values, (list, tuple)):
|
||
|
gathered = type(values)(_do_gather(v) for v in values)
|
||
|
return gathered
|
||
|
else:
|
||
|
# if not a dict, list, tuple, expect a singular tensor
|
||
|
assert isinstance(values, torch.Tensor)
|
||
|
return _do_gather(values)
|
||
|
|
||
|
|
||
|
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
|
||
|
world_size = dist.get_world_size(group)
|
||
|
this_rank = dist.get_rank(group)
|
||
|
|
||
|
def _do_gather(tensor):
|
||
|
tensor_list = None
|
||
|
if this_rank == dst_rank:
|
||
|
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||
|
dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
|
||
|
return join_fn(tensor_list, dim=join_dim)
|
||
|
|
||
|
if isinstance(values, dict):
|
||
|
gathered = {k: _do_gather(v) for k, v in values.items()}
|
||
|
return gathered
|
||
|
elif isinstance(values, (list, tuple)):
|
||
|
gathered = type(values)(_do_gather(v) for v in values)
|
||
|
return gathered
|
||
|
else:
|
||
|
# if not a dict, list, tuple, expect a singular tensor
|
||
|
assert isinstance(values, torch.Tensor)
|
||
|
return _do_gather(values)
|
||
|
|
||
|
|
||
|
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
|
||
|
if isinstance(value, torch.Tensor):
|
||
|
world_size = dist.get_world_size(group)
|
||
|
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
|
||
|
dist.all_gather(out_tensors, value, group=group)
|
||
|
if join_fn is not None:
|
||
|
out_tensors = join_fn(out_tensors, dim=join_dim)
|
||
|
return out_tensors
|
||
|
elif isinstance(value, dict):
|
||
|
return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()}
|
||
|
elif isinstance(value, (tuple, list)):
|
||
|
return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value)
|
||
|
|
||
|
|
||
|
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
|
||
|
if isinstance(value, torch.Tensor):
|
||
|
world_size = dist.get_world_size(group)
|
||
|
this_rank = dist.get_rank()
|
||
|
out_tensors = None
|
||
|
if this_rank == dst_rank:
|
||
|
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
|
||
|
dist.gather(value, out_tensors, dst=dst_rank, group=group)
|
||
|
if join_fn is not None:
|
||
|
out_tensors = join_fn(out_tensors, dim=join_dim)
|
||
|
return out_tensors
|
||
|
elif isinstance(value, dict):
|
||
|
return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
|
||
|
elif isinstance(value, (tuple, list)):
|
||
|
return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
|
||
|
|
||
|
|
||
|
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
|
||
|
if isinstance(value, torch.Tensor):
|
||
|
dist.all_reduce(value, op=op, group=group)
|
||
|
if average:
|
||
|
value /= dist.get_world_size(group)
|
||
|
elif isinstance(value, dict):
|
||
|
return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
|
||
|
elif isinstance(value, (tuple, list)):
|
||
|
return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
|
||
|
|
||
|
|
||
|
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
|
||
|
if isinstance(value, torch.Tensor):
|
||
|
return dist.broadcast(value, src=src_rank, group=group)
|
||
|
elif isinstance(value, dict):
|
||
|
return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
|
||
|
elif isinstance(value, (tuple, list)):
|
||
|
return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)
|