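""" XLA device environment

A DeviceEnv implementation backed by torch_xla: device selection, local/global
rank and world-size discovery, optional XLA AMP autocast, and step marking.
A brief usage sketch (illustrative, not part of the module API) follows the
class definition at the bottom of this file.
"""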
from contextlib import suppress

import torch

try:
    import torch_xla.core.xla_model as xm
    _HAS_XLA = True
except ImportError:
    xm = None
    _HAS_XLA = False

try:
    # only the very latest XLA builds have AMP
    import torch_xla.amp as xa
except ImportError:
    xa = None

from .device_env import DeviceEnv


def is_xla_available(xla_device_type=None):
    if not _HAS_XLA:
        return False
    supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
    return len(supported_devs) >= 1


class DeviceEnvXla(DeviceEnv):
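    """ DeviceEnv for PyTorch XLA (TPU) devices.

    Wraps the torch_xla queries for device handle, local/global rank, and world
    size, and exposes XLA AMP autocast when the installed build provides it.
    """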

    def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
        self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
        self._local_rank = xm.get_local_ordinal(local_rank)
        self._world_size = xm.xrt_world_size()
        self._distributed = self._world_size > 1
        self._global_rank = 0
        if self._distributed:
            self._global_rank = xm.get_ordinal()
        self._amp = amp  # tracked so the `amp` property reflects the requested mode
        if amp:
            assert xa is not None, 'XLA AMP is not present on this build'
            self._autocast = xa.autocast
        else:
            self._autocast = suppress
        self._memory_format = None

    @property
    def device(self):
        return self._device

    @property
    def local_rank(self):
        return self._local_rank

    @property
    def global_rank(self):
        return self._global_rank

    @property
    def is_distributed(self):
        return self._distributed

    @property
    def world_size(self):
        return self._world_size

    @property
    def is_master(self):
        return self._global_rank == 0

    @property
    def type(self) -> str:
        return 'xla'

    @property
    def amp(self) -> bool:
        return self._amp

    @property
    def autocast(self):
        return self._autocast

    def wrap_distributed(self, *modules):
        # NO-OP: no DDP-style wrapper is needed, XLA reduces gradients across
        # replicas in xm.optimizer_step()
        wrapped = list(modules)
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def to_device(self, *modules: torch.nn.Module):
        moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved

    def mark_step(self):
        # flush pending lazy-tensor ops, triggering XLA graph compilation & execution
        xm.mark_step()
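
# Example usage (a minimal sketch, assuming a torch_xla install with an attached
# TPU; `model`, `loss_fn`, `inputs`, `targets`, and `optimizer` are illustrative
# placeholders, not part of this module):
#
#     if is_xla_available():
#         env = DeviceEnvXla(amp=True)
#         model = env.to_device(model)
#         with env.autocast():
#             loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         xm.optimizer_step(optimizer)  # reduces gradients across replicas
#         env.mark_step()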