import os
from contextlib import suppress

import torch
from torch.nn.parallel import DistributedDataParallel

from .device_env import DeviceEnv


def is_cuda_available():
    return torch.cuda.is_available()


class DeviceEnvCuda(DeviceEnv):

    def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
        assert torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True
        # Defaults for the single-process, single-GPU case.
        self._local_rank = 0
        self._distributed = False
        self._world_size = 1
        self._global_rank = 0
        if 'WORLD_SIZE' in os.environ:
            self._distributed = int(os.environ['WORLD_SIZE']) > 1
        if self._distributed:
            # Resolve the local rank from the arg or the launcher-provided env var.
            if local_rank is None:
                lr = os.environ.get('LOCAL_RANK', None)
                if lr is None:
                    raise RuntimeError(
                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to a valid integer.')
                self._local_rank = int(lr)
            else:
                self._local_rank = int(local_rank)
            self._device = torch.device('cuda:%d' % self._local_rank)
            torch.cuda.set_device(self._local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            self._world_size = torch.distributed.get_world_size()
            self._global_rank = torch.distributed.get_rank()
        else:
            self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
        self._memory_format = memory_format
        self._amp = amp
        self._autocast = torch.cuda.amp.autocast if amp else suppress

    @property
    def device(self):
        return self._device

    @property
    def local_rank(self):
        return self._local_rank

    @property
    def global_rank(self):
        return self._global_rank

    @property
    def is_distributed(self):
        return self._distributed

    @property
    def world_size(self):
        return self._world_size

    @property
    def is_master(self):
        return self._local_rank == 0

    @property
    def type(self) -> str:
        return 'cuda'

    @property
    def amp(self) -> bool:
        return self._amp

    @property
    def autocast(self):
        return self._autocast

    def wrap_distributed(self, *modules, **kwargs):
        wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
        moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved
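

if __name__ == '__main__':
    # Usage sketch, not part of the original module: a rough illustration of how
    # DeviceEnvCuda might be driven from a training script. It assumes a
    # CUDA-capable machine and, because of the relative import above, must be run
    # as a module within its package (python -m ...). Under a distributed launcher
    # such as `torchrun`, WORLD_SIZE / LOCAL_RANK are set in the environment and
    # the distributed branch of __init__ takes over.
    if is_cuda_available():
        dev_env = DeviceEnvCuda(amp=True)
        # Move a toy model to the resolved device, wrap for DDP if launched distributed.
        model = dev_env.to_device(torch.nn.Linear(10, 10))
        if dev_env.is_distributed:
            model = dev_env.wrap_distributed(model)
        # autocast is torch.cuda.amp.autocast when amp=True, a no-op context otherwise.
        with dev_env.autocast():
            out = model(torch.randn(2, 10, device=dev_env.device))
        print(out.shape)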