You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/bits/updater_factory.py

31 lines
978 B

from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCuda
from .updater_xla import UpdaterXla
def create_updater(
        optimizer: torch.optim.Optimizer,
        dev_env: Optional[DeviceEnv] = None,
        clip_value: Optional[Union[Callable, float]] = None,
        clip_mode: str = 'norm',
        scaler_kwargs: Any = None) -> Updater:
    """Instantiate the Updater class that matches the device environment.

    Args:
        optimizer: Optimizer forwarded to the updater constructor.
        dev_env: Device environment used to pick the updater class. When it
            is falsy (e.g. ``None``), the result of ``get_device()`` is used.
        clip_value: Forwarded unchanged to the updater constructor.
        clip_mode: Forwarded unchanged to the updater constructor.
        scaler_kwargs: Forwarded only to the 'xla' and 'cuda' updaters; it is
            dropped for every other device type.

    Returns:
        ``UpdaterXla`` when ``dev_env.type == 'xla'``, ``UpdaterCuda`` when
        ``dev_env.type == 'cuda'``, otherwise a plain ``Updater``.
    """
    env = dev_env or get_device()

    # Arguments every updater variant receives.
    base_kwargs = {
        'optimizer': optimizer,
        'clip_value': clip_value,
        'clip_mode': clip_mode,
    }

    device_type = env.type
    if device_type == 'xla':
        updater_cls = UpdaterXla
    elif device_type == 'cuda':
        updater_cls = UpdaterCuda
    else:
        # The base Updater is built without scaler_kwargs / use_scaler.
        return Updater(**base_kwargs)

    # XLA and CUDA updaters additionally get the scaler configuration, with
    # use_scaler taken from the environment's `amp` attribute.
    return updater_cls(**base_kwargs, scaler_kwargs=scaler_kwargs, use_scaler=env.amp)