import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple
from dataclasses import dataclass, field, InitVar

import torch
import torch.distributed as dist

TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]


class DeviceEnvType(Enum):
    """ Device Environment Types
    """
    CPU = "cpu"
    CUDA = "cuda"
    XLA = "xla"


@dataclass
class DeviceEnv:
    device_type: InitVar[Optional[str]] = None
    device_index: InitVar[Optional[int]] = None
    device: torch.device = field(init=False)  # set from device_type + device_index or post_init logic
    world_size: Optional[int] = None  # set by post_init from env when None
    local_rank: Optional[int] = None  # set by post_init from env when None
    global_rank: Optional[int] = None  # set by post_init from env when None
    amp: bool = False
    autocast: Optional[Callable] = None  # set by post_init from env when None
    memory_format: Optional[torch.memory_format] = None
    dtype: Optional[torch.dtype] = None

    def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
        device_type = device_type or 'cpu'
        self.device = torch.device(device_type) if device_index is None \
            else torch.device(device_type, device_index)
        self.world_size = 1 if self.world_size is None else self.world_size
        self.local_rank = 0 if self.local_rank is None else self.local_rank
        self.global_rank = 0 if self.global_rank is None else self.global_rank
        if self.autocast is None:
            self.autocast = suppress

    @property
    def type(self) -> DeviceEnvType:
        if self.device.type == 'cpu':
            return DeviceEnvType.CPU
        elif self.device.type == 'cuda':
            return DeviceEnvType.CUDA
        elif self.device.type == 'xla':
            return DeviceEnvType.XLA
        else:
            assert False, "Unexpected device type for base DevEnv impl."

    @property
    def type_cuda(self):
        # shortcut for common cuda device type
        return self.type == DeviceEnvType.CUDA

    @property
    def type_xla(self):
        # shortcut for common xla device type
        return self.type == DeviceEnvType.XLA

    @property
    def distributed(self):
        return self.world_size > 1

    @property
    def primary(self):
        return self.local_rank == 0

    @property
    def global_primary(self):
        return self.global_rank == 0

    def wrap_distributed(self, *modules):
        pass

    def wrap_parallel(self, *modules):
        pass

    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
        moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved

    def mark_step(self):
        pass  # NO-OP for non-XLA devices

    def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
        dist.all_reduce(tensor, op=op)
        if average:
            tensor.div_(self.world_size)
        return tensor

    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
        reduce_tensor = tensor.clone()
        dist.all_reduce(reduce_tensor, op=op)
        if average:
            reduce_tensor = reduce_tensor / self.world_size
        return reduce_tensor

    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(output_tensors, tensor)
        return torch.cat(output_tensors, cat_dim)

    def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
        # dist.all_to_all expects lists of tensors; torch.chunk returns a tuple
        input_tensors = list(torch.chunk(tensor, num_splits, split_dim))
        output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
        dist.all_to_all(output_tensors, input_tensors)
        return torch.cat(output_tensors, cat_dim)

    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        dist.broadcast(tensor, src=src_rank)
        return tensor

    def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
        # all ranks must pass a tensor; non-src ranks use it as a shape/dtype template
        assert tensor is not None
        if self.global_rank != src_rank:
            tensor = torch.empty_like(tensor)
        dist.broadcast(tensor, src=src_rank)
        return tensor

    def barrier(self):
        dist.barrier()
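

# Minimal usage sketch (an assumption, not part of the original module): illustrates how a
# DeviceEnv might be constructed and used in a single-process, non-distributed CPU run.
# The collective helpers (all_reduce, all_gather, broadcast, barrier, ...) additionally
# require torch.distributed.init_process_group() and a multi-rank launch, so they are only
# referenced in comments below.
if __name__ == '__main__':
    env = DeviceEnv(device_type='cpu')
    assert env.type == DeviceEnvType.CPU
    assert env.primary and not env.distributed  # world_size defaults to 1, local_rank to 0

    model = torch.nn.Linear(4, 2)
    model = env.to_device(model)  # moves the module to env.device

    with env.autocast():  # contextlib.suppress: a no-op unless a real autocast is configured
        out = model(torch.randn(8, 4, device=env.device))
    print(out.shape)

    # In a distributed run (after dist.init_process_group, with world_size/global_rank set),
    # per-rank results could be combined with e.g.:
    #   loss = env.all_reduce(loss, average=True)
    #   env.barrier()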