pytorch-image-models/timm/bits/device_env_cuda.py

import os
from contextlib import suppress

import torch
from torch.nn.parallel import DistributedDataParallel

from .device_env import DeviceEnv


def is_cuda_available():
    return torch.cuda.is_available()


class DeviceEnvCuda(DeviceEnv):

    def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
        assert torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True
        self._local_rank = 0
        self._distributed = False
        self._world_size = 1
        self._global_rank = 0
        if 'WORLD_SIZE' in os.environ:
            self._distributed = int(os.environ['WORLD_SIZE']) > 1
        if self._distributed:
            # distributed launch: resolve the local rank from the arg or the LOCAL_RANK env var
            if local_rank is None:
                lr = os.environ.get('LOCAL_RANK', None)
                if lr is None:
                    raise RuntimeError(
                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to a valid integer.')
                self._local_rank = int(lr)
            else:
                self._local_rank = int(local_rank)
            self._device = torch.device('cuda:%d' % self._local_rank)
            torch.cuda.set_device(self._local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            self._world_size = torch.distributed.get_world_size()
            self._global_rank = torch.distributed.get_rank()
        else:
            # single-process: use the requested device index, or the default CUDA device
            self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
        self._memory_format = memory_format
        self._amp = amp
        # with AMP enabled, autocast is the CUDA autocast context manager; otherwise a no-op
        self._autocast = torch.cuda.amp.autocast if amp else suppress
    @property
    def device(self):
        return self._device

    @property
    def local_rank(self):
        return self._local_rank

    @property
    def global_rank(self):
        return self._global_rank

    @property
    def is_distributed(self):
        return self._distributed

    @property
    def world_size(self):
        return self._world_size

    @property
    def is_master(self):
        return self._local_rank == 0

    @property
    def type(self) -> str:
        return 'cuda'

    @property
    def amp(self) -> bool:
        return self._amp

    @property
    def autocast(self):
        return self._autocast

    def wrap_distributed(self, *modules, **kwargs):
        wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
        moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved
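
A minimal usage sketch (not part of the file above; it assumes the `timm.bits` module layout implied by the path header, and `model` / `x` are placeholder names, not anything defined in this repo):

import torch
from timm.bits.device_env_cuda import DeviceEnvCuda, is_cuda_available

# hypothetical usage -- runs single-process unless WORLD_SIZE/LOCAL_RANK are set by a launcher
if is_cuda_available():
    dev_env = DeviceEnvCuda(amp=True, memory_format=torch.channels_last)
    model = dev_env.to_device(torch.nn.Conv2d(3, 8, 3))   # stand-in for a real model
    if dev_env.is_distributed:
        model = dev_env.wrap_distributed(model)
    x = torch.randn(2, 3, 32, 32, device=dev_env.device)
    with dev_env.autocast():
        y = model(x)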