You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/bits/device_env_factory.py

37 lines
1005 B

import logging

from .device_env import DeviceEnv
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available
# Module-level slot for the active device environment. It stays None until
# initialize_device() stores an instance here; get_device() reads it back.
_device_env = None
def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
    """Create the module-wide device environment, or return the existing one.

    On the first call a backend is selected in this order: ``DeviceEnvXla``
    when ``is_xla_available`` accepts the optional ``xla_device_type`` keyword,
    otherwise ``DeviceEnvCuda`` when ``is_cuda_available`` reports True,
    otherwise the base ``DeviceEnv``. The chosen instance is stored in the
    module-level ``_device_env`` and returned.

    On any later call the stored instance is returned unchanged and the
    arguments of that call are ignored; a warning is logged so the caller can
    notice that their settings were not applied.

    Args:
        force_cpu: When True, skip the XLA and CUDA probes and fall back
            directly to the base ``DeviceEnv``.
        **kwargs: Forwarded as-is to ``DeviceEnvXla`` or ``DeviceEnvCuda``.
            They are not passed to the fallback ``DeviceEnv``.

    Returns:
        The active device environment.
    """
    global _device_env
    logger = logging.getLogger(__name__)

    if _device_env is not None:
        # Re-initialization is a no-op; surface it instead of silently
        # dropping the caller's arguments.
        logger.warning(
            'Device environment is already initialized; returning the existing '
            'instance and ignoring the arguments of this call.')
        return _device_env

    denv = None
    if not force_cpu:
        xla_device_type = kwargs.get('xla_device_type', None)
        if is_xla_available(xla_device_type):
            # XLA supports more than just TPU, will search in order TPU, GPU, CPU
            denv = DeviceEnvXla(**kwargs)
        elif is_cuda_available():
            denv = DeviceEnvCuda(**kwargs)
    if denv is None:
        denv = DeviceEnv()

    # Report the selection through logging rather than a debug print to stdout.
    logger.info('Initialized device environment: %s', denv)

    _device_env = denv
    return denv
def get_device() -> DeviceEnv:
    """Return the device environment stored by ``initialize_device``.

    Returns:
        The instance currently held in the module-level ``_device_env``.

    Raises:
        RuntimeError: If ``_device_env`` is still None, i.e. no device
            environment has been initialized yet.
    """
    if _device_env is not None:
        return _device_env
    raise RuntimeError('Please initialize device environment by calling initialize_device first.')