""" Lookahead Optimizer Wrapper. Implementation modified from: https://github.com/alphadl/lookahead.pytorch Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 Hacked together by / Copyright 2020 Ross Wightman """ import torch from torch.optim.optimizer import Optimizer from collections import defaultdict class Lookahead(Optimizer): def __init__(self, base_optimizer, alpha=0.5, k=6): # NOTE super().__init__() not called on purpose if not 0.0 <= alpha <= 1.0: raise ValueError(f'Invalid slow update rate: {alpha}') if not 1 <= k: raise ValueError(f'Invalid lookahead steps: {k}') defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) self._base_optimizer = base_optimizer self.param_groups = base_optimizer.param_groups self.defaults = base_optimizer.defaults self.defaults.update(defaults) self.state = defaultdict(dict) # manually add our defaults to the param groups for name, default in defaults.items(): for group in self._base_optimizer.param_groups: group.setdefault(name, default) @torch.no_grad() def update_slow(self, group): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self._base_optimizer.state[fast_p] if 'lookahead_slow_buff' not in param_state: param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) param_state['lookahead_slow_buff'].copy_(fast_p) slow = param_state['lookahead_slow_buff'] slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) fast_p.copy_(slow) def sync_lookahead(self): for group in self._base_optimizer.param_groups: self.update_slow(group) @torch.no_grad() def step(self, closure=None): loss = self._base_optimizer.step(closure) for group in self._base_optimizer.param_groups: group['lookahead_step'] += 1 if group['lookahead_step'] % group['lookahead_k'] == 0: self.update_slow(group) return loss def state_dict(self): return self._base_optimizer.state_dict() def load_state_dict(self, state_dict): self._base_optimizer.load_state_dict(state_dict) self.param_groups = self._base_optimizer.param_groups