You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1005 B
37 lines
1005 B
from .device_env import DeviceEnv
|
|
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
|
|
from .device_env_xla import DeviceEnvXla, is_xla_available
|
|
|
|
# Process-wide device environment. Stays None until initialize_device() assigns
# it; get_device() reads it and raises if it is still None.
_device_env = None
|
|
|
|
|
|
def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
    """Create (on first call) and return the process-wide device environment.

    The first call selects a backend and stores it in the module-level
    ``_device_env``. Every later call returns that stored instance; its
    arguments are ignored and a ``RuntimeWarning`` is issued.

    Selection order when ``force_cpu`` is False:
      1. ``DeviceEnvXla(**kwargs)`` if
         ``is_xla_available(kwargs.get('xla_device_type'))`` is true.
      2. ``DeviceEnvCuda(**kwargs)`` if ``is_cuda_available()`` is true.
      3. Otherwise (and always when ``force_cpu`` is True) a plain
         ``DeviceEnv()``.

    Args:
        force_cpu: Skip the XLA/CUDA probes and use the base ``DeviceEnv``.
        **kwargs: Forwarded to the XLA or CUDA environment constructor.
            ``xla_device_type`` is additionally used to probe XLA
            availability. Not forwarded to the fallback ``DeviceEnv()``.

    Returns:
        The (possibly previously created) device environment.
    """
    import logging
    import warnings

    global _device_env
    if _device_env is not None:
        # Arguments of a repeat call cannot take effect; say so instead of
        # silently returning an environment the caller may not expect.
        warnings.warn(
            'Device environment already initialized; ignoring arguments to '
            'initialize_device() and returning the existing environment.',
            RuntimeWarning,
            stacklevel=2,
        )
        return _device_env

    denv = None
    if not force_cpu:
        xla_device_type = kwargs.get('xla_device_type', None)
        if is_xla_available(xla_device_type):
            # XLA supports more than just TPU, will search in order TPU, GPU, CPU
            denv = DeviceEnvXla(**kwargs)
        elif is_cuda_available():
            denv = DeviceEnvCuda(**kwargs)

    if denv is None:
        denv = DeviceEnv()

    # Was a stray debug print(); report through logging instead of stdout.
    logging.getLogger(__name__).debug('Initialized device environment: %s', denv)
    _device_env = denv
    return denv
|
|
|
|
|
|
def get_device() -> DeviceEnv:
    """Return the device environment created by initialize_device().

    Raises:
        RuntimeError: If initialize_device() has not been called yet.
    """
    current = _device_env
    if current is not None:
        return current
    raise RuntimeError('Please initialize device environment by calling initialize_device first.')
|
|
|
|
|