pytorch-image-models/timm/optim/lookahead.py

""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.alpha = alpha  # slow weight interpolation factor
        self.k = k  # number of fast steps between slow weight updates
        self.base_optimizer = base_optimizer
        # Share the param_groups container with the base optimizer so both
        # operate on the same parameters and hyper-parameters.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.state = defaultdict(dict)
        for group in self.param_groups:
            group['step_counter'] = 0

    def update_slow_weights(self, group):
        for fast_p in group['params']:
            # Skip params that never received a gradient.
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                # Initialize the slow weights to the current fast weights.
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # slow <- slow + alpha * (fast - slow), then copy back into the fast weights.
            slow.add_(fast_p.data - slow, alpha=self.alpha)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        # Force a slow weight update for every param group, e.g. before eval or checkpointing.
        for group in self.param_groups:
            self.update_slow_weights(group)

    def step(self, closure=None):
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['step_counter'] += 1
            # Every k fast steps, pull the slow weights toward the fast weights.
            if group['step_counter'] % self.k == 0:
                self.update_slow_weights(group)
        return loss

    def state_dict(self):
        fast_state_dict = self.base_optimizer.state_dict()
        # The slow state is keyed by parameter tensors; key by id() instead so
        # tensors are not used as dict keys in the serialized dict.
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied')
            state_dict['slow_state'] = defaultdict(dict)
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],
        }
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        # Restore the slow weights via the base Optimizer machinery, then the fast state.
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.base_optimizer.load_state_dict(fast_state_dict)
        # load_state_dict rebinds param_groups; re-share the container with the base optimizer.
        self.param_groups = self.base_optimizer.param_groups

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer`'s `param_groups`.

        This can be useful when fine tuning a pre-trained network, as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along
                with group specific optimization options.
        """
        param_group['step_counter'] = 0
        self.base_optimizer.add_param_group(param_group)
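

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original module).
    # The toy parameter, learning rate, and quadratic loss below are assumptions
    # for demonstration; any torch.optim optimizer can serve as the base.
    weight = torch.nn.Parameter(torch.randn(10))
    base_optimizer = torch.optim.SGD([weight], lr=0.1)
    optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)

    for _ in range(2 * optimizer.k):  # two full lookahead cycles
        base_optimizer.zero_grad()
        loss = (weight ** 2).sum()
        loss.backward()
        optimizer.step()  # fast step; the slow update fires every k steps

    # Flush any pending fast steps into the slow weights, e.g. before eval.
    optimizer.sync_lookahead()
    print('final loss:', float((weight ** 2).sum()))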