You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
978 B
31 lines
978 B
from typing import Callable, Optional, Union, Any
|
|
|
|
import torch
|
|
|
|
from .device_env import DeviceEnv
|
|
from .device_env_factory import get_device
|
|
from .updater import Updater
|
|
from .updater_cuda import UpdaterCuda
|
|
from .updater_xla import UpdaterXla
|
|
|
|
|
|
def create_updater(
        optimizer: torch.optim.Optimizer,
        dev_env: Optional[DeviceEnv] = None,
        clip_value: Optional[Union[Callable, float]] = None,
        clip_mode: str = 'norm',
        scaler_kwargs: Any = None) -> Updater:
    """Construct the Updater variant matching the device environment.

    Args:
        optimizer: Optimizer handed to the constructed updater.
        dev_env: Device environment used to pick the updater class. When it
            is not supplied (or is falsy), the result of ``get_device()`` is
            used instead.
        clip_value: Passed through unchanged to the updater as ``clip_value``.
        clip_mode: Passed through unchanged to the updater as ``clip_mode``.
        scaler_kwargs: Passed through to the XLA and CUDA updaters only; the
            base ``Updater`` is constructed without it.

    Returns:
        ``UpdaterXla`` when ``dev_env.type == 'xla'``, ``UpdaterCuda`` when
        ``dev_env.type == 'cuda'``, otherwise a plain ``Updater``. The two
        device-specific updaters also receive ``use_scaler=dev_env.amp``.
    """
    dev_env = dev_env or get_device()

    # Arguments every updater flavour accepts.
    common_args = dict(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)

    if dev_env.type == 'xla':
        return UpdaterXla(**common_args, scaler_kwargs=scaler_kwargs, use_scaler=dev_env.amp)
    if dev_env.type == 'cuda':
        return UpdaterCuda(**common_args, scaler_kwargs=scaler_kwargs, use_scaler=dev_env.amp)

    # Fallback device types: the base Updater is not given scaler_kwargs or use_scaler.
    return Updater(**common_args)
|