from dataclasses import dataclass, field, InitVar
from functools import partial
from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from .grad_clip import get_clip_grad_fn, get_clip_parameters


@dataclass
class Updater:
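    """Bundle a model with its optimizer and apply loss updates to it.

    apply() runs backward on a loss, optionally clips gradients, steps the
    optimizer, and zeroes gradients; with accumulate=True it only backprops
    so gradients build up across calls.

    Usage sketch (illustrative only -- assumes an existing model, optimizer,
    criterion, and batch (x, y), and that get_clip_grad_fn accepts a 'norm' mode):

        updater = Updater(model=model, optimizer=optimizer, clip_fn='norm', clip_value=1.0)
        loss = criterion(model(x), y)
        updater.apply(loss)
    """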

    model: nn.Module = None
    optimizer: torch.optim.Optimizer = None  # FIXME handle multiple optimizers per-model
    clip_fn: Optional[Union[Callable, str]] = None
    clip_value: Optional[float] = None
    clip_params_fn: Optional[Callable] = None
    grad_scaler: Optional[Callable] = None
    create_graph: Optional[bool] = None
    after_step_closure: bool = False

    def __post_init__(self):
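        """Validate the model/optimizer pair and resolve the gradient clipping setup.

        A string clip_fn is resolved via get_clip_grad_fn; AGC-style clipping
        skips the last parameters (skip_last=2, presumably the classifier
        weight and bias). create_graph defaults from the optimizer's
        second_order attribute so higher-order optimizers retain the graph.
        """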
        assert self.model is not None
        assert self.optimizer is not None
        if self.clip_fn is not None:
            if isinstance(self.clip_fn, Callable):
                skip_last = 0
            else:
                assert isinstance(self.clip_fn, str)
                skip_last = 2 if 'agc' in self.clip_fn else 0
                self.clip_fn = get_clip_grad_fn(self.clip_fn)
            assert self.clip_value is not None
            self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
        if self.create_graph is None:
            self.create_graph = getattr(self.optimizer, 'second_order', False)
        self.after_step_closure = False

    def reset(self):
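        """Zero accumulated gradients on the wrapped optimizer."""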
        self.optimizer.zero_grad()

    def apply(self, loss: torch.Tensor, accumulate=False):
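        """Backprop the loss and, unless accumulating, clip, step, and reset.

        Set accumulate=True to add this loss's gradients without stepping,
        e.g. for gradient accumulation across micro-batches.
        """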
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        self.reset()

    def get_average_lr(self):
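        """Return the mean learning rate over param groups with a positive lr."""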
        lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
        return sum(lrl) / len(lrl)

    def state_dict(self):
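        """Return a state dict holding optimizer (and grad scaler, if any) state."""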
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
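        """Restore optimizer (and grad scaler, if present) state from a state dict."""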
        if 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
        if 'grad_scaler' in state_dict and self.grad_scaler is not None:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

    def after_step(self, after_step_fn, *args):
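        """Invoke a post-step callback immediately.

        The after_step_closure flag suggests some Updater variants defer this
        call until after a deferred optimizer step; this base implementation
        simply calls it right away.
        """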
        after_step_fn(*args)