from typing import Dict, Tuple, List, Union, Any, Callable

import torch
from torch.distributed import ReduceOp

from timm.utils import unwrap_model

from .device_env import DeviceEnv

TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]

def _validate_type(tensor: TensorSeq):
    # top-level check only: empty containers pass through, container elements
    # are validated as the recursive helpers descend into them
    if isinstance(tensor, (dict, list, tuple)):
        if not tensor:
            return
    else:
        assert isinstance(tensor, torch.Tensor)

def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
    """ Synchronize BatchNorm running stats across distributed processes.

    If `reduce` is True the stats are averaged over the whole group, otherwise
    rank 0's stats are broadcast to all other ranks.
    """
    if dev_env is None:
        dev_env = DeviceEnv.instance()
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                dev_env.all_reduce_(bn_buf, average=True)
            else:
                # broadcast bn stats from rank 0 to whole group
                dev_env.broadcast_(bn_buf, 0)

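# Usage sketch (illustrative comment only; `model` and the epoch loop are
# assumptions about the caller, not part of this module):
#
#   # at the end of each distributed training epoch
#   distribute_bn(model, reduce=True)   # average BN stats over all ranks
#   # or force every rank to adopt rank 0's stats:
#   distribute_bn(model, reduce=False)
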
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
    """ Recursive all gather via DeviceEnv distributed primitives
    FIXME add group support
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = DeviceEnv.instance()
    if isinstance(tensor, torch.Tensor):
        return dev_env.all_gather(tensor, cat_dim=cat_dim)
    elif isinstance(tensor, dict):
        # pass cat_dim through so nested tensors gather along the same dim
        return {k: all_gather_recursive(v, cat_dim=cat_dim, dev_env=dev_env) for k, v in tensor.items()}
    elif isinstance(tensor, (tuple, list)):
        return type(tensor)(all_gather_recursive(v, cat_dim=cat_dim, dev_env=dev_env) for v in tensor)

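# Usage sketch (illustrative; the `outputs` structure is an assumption about
# the caller): gathers every tensor in an arbitrarily nested container, e.g.
#
#   outputs = {'logits': logits, 'targets': targets}   # per-rank eval results
#   gathered = all_gather_recursive(outputs)           # each value now spans all ranks
#
# Tensors are concatenated along `cat_dim`; the container structure is preserved.
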
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
    """ Recursive all reduce via DeviceEnv distributed primitives
    FIXME add group support
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = DeviceEnv.instance()
    if isinstance(tensor, torch.Tensor):
        return dev_env.all_reduce_(tensor, op=op, average=average)
    elif isinstance(tensor, dict):
        return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()}
    elif isinstance(tensor, (tuple, list)):
        return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor)

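# Usage sketch (illustrative names): reduce a dict of per-rank scalars, e.g.
#
#   metrics = {'loss': loss.detach(), 'correct': num_correct}
#   all_reduce_recursive(metrics, average=True)
#
# Each tensor is reduced across ranks (in place, per the `all_reduce_` naming).
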
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
    """ Recursive broadcast via DeviceEnv distributed primitives
    FIXME add group support
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = DeviceEnv.instance()
    if isinstance(tensor, torch.Tensor):
        return dev_env.broadcast_(tensor, src_rank=src_rank)
    elif isinstance(tensor, dict):
        return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()}
    elif isinstance(tensor, (tuple, list)):
        return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor)

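# Usage sketch (illustrative): make rank 0's values authoritative on every
# rank, e.g. broadcast_recursive({'epoch': epoch_t, 'lr': lr_t}, src_rank=0).
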
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
    """ All gather a flat Tensor sequence (dict, list, tuple) of same-shape tensors.

    Values are stacked into one tensor so the gather runs as a single collective op.
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = DeviceEnv.instance()

    with torch.no_grad():
        names = None
        # merge values into one tensor for gathering
        if isinstance(tensor, dict):
            names = tensor.keys()
            gather_values = tuple(tensor.values())
        elif isinstance(tensor, (tuple, list)):
            gather_values = tensor
        else:
            gather_values = (tensor,)

        # stack along a new leading dim, gather along cat_dim + 1, then unbind back
        gather_values = torch.stack(gather_values, dim=0)
        gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)

        # separate gathered values into original structure
        if isinstance(tensor, dict):
            gather_values = {k: v for k, v in zip(names, gather_values)}
        elif isinstance(tensor, (tuple, list)):
            gather_values = type(tensor)(v for v in gather_values)
        else:
            gather_values = gather_values[0]

    return gather_values

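# Usage sketch (illustrative): unlike all_gather_recursive, this issues one
# collective call for the whole sequence, at the cost of requiring every
# tensor in it to share the same shape, e.g.
#
#   parts = [preds, targets]                 # same-shape per-rank tensors
#   preds, targets = all_gather_sequence(parts)
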
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
    """ All reduce a flat Tensor sequence (dict, list, tuple) of same-shape tensors.

    Values are stacked into one tensor so the reduction runs as a single collective op.

    Args:
        tensor (TensorSeq): inputs to be reduced. All tensors must have the same shape.
        op (ReduceOp): reduction operation to apply (default: sum)
        average (bool): whether to average instead of sum
    Returns:
        a sequence with the same type as input (dict, list, tuple)
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = DeviceEnv.instance()

    with torch.no_grad():
        names = None
        # merge values into one tensor for reduction
        if isinstance(tensor, dict):
            names = tensor.keys()
            reduce_values = tuple(tensor.values())
        elif isinstance(tensor, (tuple, list)):
            reduce_values = tensor
        else:
            reduce_values = (tensor,)

        # stack along a new leading dim, reduce once, then unbind back
        reduce_values = torch.stack(reduce_values, dim=0)
        dev_env.all_reduce_(reduce_values, op=op, average=average)
        reduce_values = reduce_values.unbind(dim=0)

        # separate reduced values into original structure
        if isinstance(tensor, dict):
            reduce_values = {k: v for k, v in zip(names, reduce_values)}
        elif isinstance(tensor, (tuple, list)):
            reduce_values = type(tensor)(v for v in reduce_values)
        else:
            reduce_values = reduce_values[0]

    return reduce_values
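

# Usage sketch (illustrative): one collective reduction over a metrics dict of
# same-shape (here scalar) tensors, e.g.
#
#   totals = all_reduce_sequence({'loss': loss_sum, 'n': n_samples}, average=False)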