You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
3.7 KiB
117 lines
3.7 KiB
import os
|
|
from contextlib import suppress
|
|
from dataclasses import dataclass, field, InitVar
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch.distributed import ReduceOp
|
|
|
|
try:
|
|
import torch_xla.core.xla_model as xm
|
|
_HAS_XLA = True
|
|
except ImportError as e:
|
|
xm = None
|
|
_HAS_XLA = False
|
|
|
|
try:
|
|
# only the very latest XLA builds have AMP
|
|
import torch_xla.amp as xa
|
|
except ImportError as e:
|
|
xa = None
|
|
|
|
from .device_env import DeviceEnv, DeviceEnvType, TensorList
|
|
|
|
|
|
# Maps torch.distributed reduce ops to the reduce-type constants accepted by
# xm.all_reduce. The table is only populated when torch_xla imported
# successfully: `xm` is None otherwise, and dereferencing it here would raise
# AttributeError at import time, defeating the guarded import and making
# is_xla_available() impossible to call on machines without XLA.
_PT_TO_XM_OP = {
    ReduceOp.SUM: xm.REDUCE_SUM,
    ReduceOp.PRODUCT: xm.REDUCE_MUL,
    ReduceOp.MIN: xm.REDUCE_MIN,
    ReduceOp.MAX: xm.REDUCE_MAX,
    ReduceOp.BAND: xm.REDUCE_AND,
    ReduceOp.BOR: xm.REDUCE_OR,
} if xm is not None else {}
|
|
|
|
|
|
def is_xla_available(xla_device_type=None):
    """Return True if torch_xla imported and reports at least one device.

    Args:
        xla_device_type: Optional device kind, forwarded unchanged as the
            ``devkind`` argument of ``xm.get_xla_supported_devices``.

    Returns:
        bool: False when torch_xla is not installed or no supported device
        is reported, True otherwise.
    """
    if not _HAS_XLA:
        return False
    supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
    # The query may yield None (rather than an empty list) when no device of
    # the requested kind exists; truthiness covers both, where len() would
    # raise TypeError on None.
    return bool(supported_devs)
|
|
|
|
|
|
@dataclass
class DeviceEnvXla(DeviceEnv):
    """DeviceEnv implementation backed by torch_xla (`xm`) collectives."""

    def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]):
        """Resolve the XLA device, rank/world-size info and autocast context.

        Args:
            device_type: Optional XLA device kind; matched case-insensitively
                against 'TPU', 'GPU' and 'CPU'. None lets XLA choose.
            device_idx: Optional device ordinal passed to ``xm.xla_device``.
                Must be None when ``self.distributed`` is truthy.
        """
        # NOTE(review): device_type / device_idx are presumably InitVar fields
        # declared on DeviceEnv -- confirm against the base class.
        if device_type is not None:
            device_type = device_type.upper()
            assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
        self.device = xm.xla_device(n=device_idx, devkind=device_type)
        self.world_size = xm.xrt_world_size()
        if self.distributed:
            assert device_idx is None, "device_index is based on local rank for distributed XLA mode"
            self.local_rank = xm.get_local_ordinal()
            self.global_rank = xm.get_ordinal()
        else:
            self.local_rank = 0
            self.global_rank = 0
        if self.amp:
            # `xa` is None when the installed torch_xla build has no AMP module.
            assert xa is not None, 'XLA AMP is not present on this build'
        if self.autocast is None:
            # Fall back to a do-nothing context manager when AMP is disabled.
            self.autocast = xa.autocast if self.amp else suppress

    @property
    def type(self) -> DeviceEnvType:
        """Identify this environment as XLA."""
        return DeviceEnvType.XLA

    def wrap_distributed(self, *modules):
        """Return the modules unchanged (no distributed wrapper is applied).

        Returns:
            The single module when exactly one is passed, otherwise a list of
            the given modules.
        """
        wrapped = list(modules)  # NO-OP
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def wrap_parallel(self, *modules):
        """Not supported for XLA; always fails."""
        # NOTE: kept as an assert so callers see the same AssertionError as
        # before; be aware asserts are stripped when Python runs with -O.
        assert False, "Not implemented"

    def mark_step(self):
        """Forward to ``xm.mark_step()``."""
        xm.mark_step()

    def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
        """All-reduce a single tensor and return the reduced result.

        Args:
            tensor: Tensor to reduce; lists/tuples are rejected.
            op: torch.distributed ReduceOp, translated via ``_PT_TO_XM_OP``.
            average: When True the result is scaled by ``1 / world_size``.

        Returns:
            The value returned by ``xm.all_reduce`` for the bare tensor.
        """
        assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
        op = _PT_TO_XM_OP[op]
        scale = 1.0 / self.world_size if average else 1.0
        return xm.all_reduce(op, tensor, scale=scale)

    def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
        """In-place all-reduce of a tensor or a list of tensors.

        Args:
            tensor: A tensor or list of tensors to reduce in place.
            op: torch.distributed ReduceOp, translated via ``_PT_TO_XM_OP``.
            average: When True the result is scaled by ``1 / world_size``.

        Returns:
            The same tensor (or list) that was passed in.
        """
        op = _PT_TO_XM_OP[op]
        scale = 1.0 / self.world_size if average else 1.0
        wrapped = False
        if isinstance(tensor, torch.Tensor):
            tensor = [tensor]  # bare tensors are not operated on in-place
            wrapped = True
        xm.all_reduce(op, tensor, scale=scale)
        if wrapped:
            tensor = tensor[0]
        return tensor

    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        """Gather ``tensor`` via ``xm.all_gather`` using ``cat_dim``."""
        output = xm.all_gather(tensor, cat_dim)
        return output

    def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0):
        """Forward to ``xm.all_to_all``.

        Note the argument order differs from this method's signature: xm is
        called with (tensor, split_dim, cat_dim, num_splits).
        """
        output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits)
        return output

    def broadcast(self, tensor: torch.Tensor, src_rank=0):
        """Return the value of ``tensor`` held by ``src_rank``.

        Implemented as a sum all-reduce in which every rank other than
        ``src_rank`` contributes zeros, so the sum equals the source value.

        Args:
            tensor: Tensor to broadcast; on non-source ranks only its
                shape/dtype/device are used (via ``zeros_like``).
            src_rank: Global rank whose tensor is broadcast.

        Returns:
            The reduced tensor produced by ``xm.all_reduce``.
        """
        if self.global_rank != src_rank:
            contribution = torch.zeros_like(tensor)
        else:
            contribution = tensor
        # all_reduce on a bare tensor is not in-place, so its return value must
        # be propagated; returning the input would hand non-source ranks back
        # their own, un-broadcast data.
        return xm.all_reduce('sum', contribution)

    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        """Broadcast from ``src_rank`` and copy the result into ``tensor``.

        Returns:
            ``tensor`` after the in-place copy.
        """
        out_tensor = self.broadcast(tensor, src_rank)
        return tensor.copy_(out_tensor)

    def barrier(self):
        """Synchronize via ``xm.rendezvous`` on a fixed tag."""
        xm.rendezvous('timm.bits.dist_barrier')
|