pytorch-image-models/timm/bits/updater_xla.py

from dataclasses import dataclass, field, InitVar
from typing import Any, Dict

import torch
import torch.nn as nn

try:
    import torch_xla.core.xla_model as xm
    _HAS_XLA = True
except ImportError as e:
    xm = None
    _HAS_XLA = False

try:
    # only the very latest XLA builds have AMP
    import torch_xla.amp as xa
except ImportError as e:
    xa = None

from .updater import Updater
@dataclass
class UpdaterXla(Updater):

    def __post_init__(self):
        super().__post_init__()
        self.after_step_closure = True

    def apply(self, loss: torch.Tensor, accumulate: bool = False):
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        xm.reduce_gradients(self.optimizer)  # all-reduce gradients across XLA replicas
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        xm.mark_step()  # end the lazy XLA step and trigger graph execution
        self.reset()

    def after_step(self, after_step_fn, *args):
        xm.add_step_closure(after_step_fn, args)
@dataclass
class UpdaterXlaWithScaler(UpdaterXla):

    scaler_kwargs: InitVar[Dict[str, Any]] = None

    def __post_init__(self, scaler_kwargs: Dict[str, Any]):
        super().__post_init__()
        scaler_kwargs = scaler_kwargs or {}
        assert xa is not None, 'XLA AMP not present in this build'
        self.scaler = xa.GradScaler(**scaler_kwargs)

    def apply(self, loss: torch.Tensor, accumulate: bool = False):
        self.scaler.scale(loss).backward(create_graph=self.create_graph)
        if accumulate:
            # unscale first?
            return
        xm.reduce_gradients(self.optimizer)  # all-reduce gradients across XLA replicas
        if self.clip_fn is not None:
            self.scaler.unscale_(self.optimizer)  # unscale the gradients of optimizer's assigned params in-place
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        xm.mark_step()
        self.reset()
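
A minimal usage sketch, not part of the file above: it assumes an `updater = UpdaterXla(...)` already constructed with the model's optimizer (and optional clipping) via the fields defined on the `Updater` base class, and placeholder `model`, `loss_fn`, `loader`, and `log_fn`; XLA device placement and parallel-loader wiring are omitted.

# hypothetical driver loop, for illustration only
for step, (inputs, targets) in enumerate(loader):
    loss = loss_fn(model(inputs), targets)
    updater.apply(loss)                     # backward, gradient all-reduce, optimizer.step(), xm.mark_step()
    updater.after_step(log_fn, step, loss)  # log_fn runs as an XLA step closure after the step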