You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/bits/updater.py

69 lines
2.4 KiB

from dataclasses import dataclass, field, InitVar
from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from .grad_clip import get_clip_grad_fn, get_clip_parameters
@dataclass
class Updater:
    """Run backward, optional gradient clipping, and an optimizer step for a model.

    A training loop calls ``apply(loss)`` once per batch. Gradient clipping is
    enabled by passing ``clip_fn`` (either a callable, or a string name that is
    resolved through ``get_clip_grad_fn``) together with ``clip_value``.

    Attributes:
        model: module whose parameters are optimized. Required; asserted in
            ``__post_init__``.
        optimizer: optimizer that is stepped in ``apply``. Required; asserted in
            ``__post_init__``.
        clip_fn: clipping callable, invoked as ``clip_fn(params, clip_value)``,
            or a string name resolved via ``get_clip_grad_fn``.
        clip_value: value passed as the second argument to ``clip_fn``. Asserted
            to be set when ``clip_fn`` is given as a string.
        clip_params_fn: zero-argument callable that returns the parameters to
            clip. If clipping is enabled and this is ``None``, a default is built
            from ``get_clip_parameters``.
        grad_scaler: optional object with ``state_dict``/``load_state_dict``.
            This base class only checkpoints it; ``apply`` does not use it.
        create_graph: forwarded to ``loss.backward``. When ``None``, it defaults
            to the optimizer's ``second_order`` attribute (``False`` if absent).
        after_step_closure: always reset to ``False`` by ``__post_init__``.
    """

    model: Optional[nn.Module] = None
    optimizer: Optional[torch.optim.Optimizer] = None  # FIXME handle multiple optimizers per-model
    clip_fn: Optional[Union[Callable, str]] = None
    clip_value: Optional[float] = None
    clip_params_fn: Optional[Callable] = None
    grad_scaler: Optional[Callable] = None
    create_graph: Optional[bool] = None
    after_step_closure: bool = False

    def __post_init__(self):
        """Validate required fields and resolve the clipping and backward settings."""
        assert self.model is not None, 'Updater requires a model'
        assert self.optimizer is not None, 'Updater requires an optimizer'
        if self.clip_fn is not None:
            if callable(self.clip_fn):
                skip_last = 0
            else:
                assert isinstance(self.clip_fn, str)
                # NOTE(review): 'agc' variants pass skip_last=2 to get_clip_parameters,
                # presumably to exclude the final layer(s) from clipping. Confirm in grad_clip.
                skip_last = 2 if 'agc' in self.clip_fn else 0
                self.clip_fn = get_clip_grad_fn(self.clip_fn)
                assert self.clip_value is not None
            if self.clip_params_fn is None:
                # Build the default parameter selector only when the caller did not
                # supply one; a user-provided clip_params_fn used to be overwritten.
                self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
        if self.create_graph is None:
            # Default to the optimizer's `second_order` flag, if it has one. Presumably
            # such optimizers need the graph kept for higher-order gradients.
            self.create_graph = getattr(self.optimizer, 'second_order', False)
        self.after_step_closure = False

    def reset(self):
        """Clear accumulated gradients via ``optimizer.zero_grad()``."""
        self.optimizer.zero_grad()

    def apply(self, loss: torch.Tensor, accumulate=False):
        """Backpropagate ``loss``; unless accumulating, clip, step, and reset grads.

        Args:
            loss: loss tensor. ``backward`` is called without a gradient argument,
                so this must be a single-element tensor.
            accumulate: if True, only run backward and keep the gradients for a
                later ``apply`` call. No clipping, optimizer step, or reset happens.
        """
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        self.reset()

    def get_average_lr(self):
        """Return the mean ``lr`` over param groups whose ``lr`` is positive.

        Returns:
            The average positive learning rate, or ``0.0`` when no group has a
            positive ``lr``. Previously this case raised ZeroDivisionError.
        """
        lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
        if not lrl:
            return 0.0
        return sum(lrl) / len(lrl)

    def state_dict(self):
        """Return a checkpoint dict.

        The dict has an ``'optimizer'`` key and, if a grad scaler is set, a
        ``'grad_scaler'`` key.
        """
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        # The original built this dict but never returned it (callers got None).
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore optimizer and grad scaler state from a dict made by ``state_dict``.

        Missing keys are skipped. ``'grad_scaler'`` is also skipped when this
        updater has no grad scaler.
        """
        if 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
        if 'grad_scaler' in state_dict and self.grad_scaler is not None:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

    def after_step(self, after_step_fn, *args):
        """Call ``after_step_fn(*args)`` immediately; this base class adds no closure handling."""
        after_step_fn(*args)