You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/bits/updater_factory.py

40 lines
1.3 KiB

from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
def create_updater(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
scaler_kwargs: Any = None,
dev_env: Optional[DeviceEnv] = None,
deepspeed: bool = False,
) -> Updater:
if not dev_env:
dev_env = get_device()
updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
use_scaler = dev_env.amp
if use_scaler:
updater_kwargs['scaler_kwargs'] = scaler_kwargs
updater_cls = Updater
if dev_env.type == DeviceEnvType.XLA:
updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla
elif dev_env.type == DeviceEnvType.CUDA and use_scaler:
updater_cls = UpdaterCudaWithScaler
elif deepspeed:
del updater_kwargs['scaler_kwargs']
updater_cls = UpdaterDeepSpeed
return updater_cls(**updater_kwargs)