""" Lookahead Optimizer Wrapper.

Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
from collections import OrderedDict, defaultdict
from itertools import chain

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    """Wrap ``base_optimizer`` with Lookahead slow-weight averaging.

    Every ``k`` calls to :meth:`step` (counted per param group via the group's
    ``step_counter`` entry), each parameter of that group that has a gradient
    is pulled toward a slow copy kept in ``self.state``::

        slow = slow + alpha * (fast - slow)
        fast = slow

    ``param_groups`` and ``defaults`` are shared with the base optimizer, so
    changes made through the wrapper (e.g. by an LR scheduler) reach it.

    Args:
        base_optimizer: inner optimizer whose ``step`` produces the fast weights.
        alpha: slow update rate, must lie in ``[0, 1]``.
        k: number of fast steps between slow updates, must be ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE: Optimizer.__init__ is deliberately not called; the wrapper reuses
        # the base optimizer's param groups instead of building its own.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.alpha = alpha
        self.k = k
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.state = defaultdict(dict)
        # Hook registries normally created by Optimizer.__init__. Newer torch
        # versions read them from inherited machinery (e.g. the step profiling
        # wrapper installed by zero_grad); creating them is harmless elsewhere.
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()
        for group in self.param_groups:
            group["step_counter"] = 0

    @torch.no_grad()
    def update_slow_weights(self, group):
        """Move the slow weights of ``group`` toward its fast weights, then reset fast to slow.

        Parameters whose ``grad`` is ``None`` are skipped. A parameter's slow
        buffer is created lazily, as a copy of its current fast weights, the
        first time it is visited here.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p)
                param_state["slow_buffer"].copy_(fast_p)
            slow = param_state["slow_buffer"]
            # slow <- slow + alpha * (fast - slow); keyword form replaces the
            # deprecated add_(Number, Tensor) overload.
            slow.add_(fast_p - slow, alpha=self.alpha)
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Apply a slow-weight update to every param group immediately."""
        for group in self.param_groups:
            self.update_slow_weights(group)

    def step(self, closure=None):
        """Run one base-optimizer step; every ``k``-th step per group also updates slow weights.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns (the loss, if a closure is given).
        """
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['step_counter'] += 1
            if group['step_counter'] % self.k == 0:
                self.update_slow_weights(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict extended with a ``slow_state`` entry.

        ``slow_state`` is keyed by the same packed parameter indices that the
        base optimizer writes into ``param_groups[*]['params']``. This lets
        :meth:`load_state_dict` match slow buffers back to parameters in a
        fresh process, where tensor ``id()`` values differ.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        param_groups = fast_state_dict["param_groups"]
        # The packed groups list parameters in the same order as the live groups.
        packed_keys = chain.from_iterable(g["params"] for g in param_groups)
        live_params = chain.from_iterable(g["params"] for g in self.param_groups)
        key_of = {id(p): key for key, p in zip(packed_keys, live_params)}
        # Entries not belonging to a current parameter cannot be matched on load.
        slow_state = {
            key_of[id(p)]: v for p, v in self.state.items() if id(p) in key_of
        }
        return {
            "state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Load fast state into the base optimizer and slow buffers into this wrapper.

        Accepts dicts produced by :meth:`state_dict` as well as plain optimizer
        state dicts without ``slow_state``. In the latter case, slow buffers are
        re-created lazily and step counters start at 0. ``state_dict`` itself is
        not modified.
        """
        if 'slow_state' in state_dict:
            slow_state = state_dict['slow_state']
        else:
            print('Loading state_dict from optimizer without Lookahead applied')
            slow_state = {}
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)
        # The base optimizer rebuilds its param_groups list on load; re-share it
        # so schedulers acting on this wrapper keep reaching the base optimizer.
        self.param_groups = self.base_optimizer.param_groups
        for group in self.param_groups:
            # Plain (non-Lookahead) checkpoints carry no step counter.
            group.setdefault("step_counter", 0)

        saved_keys = chain.from_iterable(g["params"] for g in state_dict["param_groups"])
        live_params = chain.from_iterable(g["params"] for g in self.param_groups)
        param_of = dict(zip(saved_keys, live_params))
        self.state = defaultdict(dict)
        for key, saved in slow_state.items():
            param = param_of.get(key)
            if param is None:
                # Not the index of any current parameter (e.g. an id()-based key
                # written by an older version of this wrapper); cannot be matched.
                continue
            self.state[param] = {
                name: self._cast_like(param, value) for name, value in saved.items()
            }

    @staticmethod
    def _cast_like(param, value):
        """Return a copy of tensor ``value`` on ``param``'s device (and dtype if floating point).

        Non-tensor values are returned unchanged.
        """
        if not torch.is_tensor(value):
            return value
        dtype = param.dtype if value.is_floating_point() else value.dtype
        return value.to(device=param.device, dtype=dtype, copy=True)

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer`'s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
        """
        param_group['step_counter'] = 0
        self.base_optimizer.add_param_group(param_group)