diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py index cc1fb495..7a58e0a6 100644 --- a/timm/optim/lookahead.py +++ b/timm/optim/lookahead.py @@ -13,37 +13,40 @@ class Lookahead(Optimizer): raise ValueError(f'Invalid slow update rate: {alpha}') if not 1 <= k: raise ValueError(f'Invalid lookahead steps: {k}') - self.alpha = alpha - self.k = k + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) self.base_optimizer = base_optimizer self.param_groups = self.base_optimizer.param_groups self.defaults = base_optimizer.defaults + self.defaults.update(defaults) self.state = defaultdict(dict) - for group in self.param_groups: - group["step_counter"] = 0 + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) - def update_slow_weights(self, group): + def update_slow(self, group): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self.state[fast_p] - if "slow_buffer" not in param_state: - param_state["slow_buffer"] = torch.empty_like(fast_p.data) - param_state["slow_buffer"].copy_(fast_p.data) - slow = param_state["slow_buffer"] - slow.add_(self.alpha, fast_p.data - slow) + if 'slow_buffer' not in param_state: + param_state['slow_buffer'] = torch.empty_like(fast_p.data) + param_state['slow_buffer'].copy_(fast_p.data) + slow = param_state['slow_buffer'] + slow.add_(group['lookahead_alpha'], fast_p.data - slow) fast_p.data.copy_(slow) def sync_lookahead(self): for group in self.param_groups: - self.update_slow_weights(group) + self.update_slow(group) def step(self, closure=None): + #assert id(self.param_groups) == id(self.base_optimizer.param_groups) loss = self.base_optimizer.step(closure) for group in self.param_groups: - group['step_counter'] += 1 - if group['step_counter'] % self.k == 0: - self.update_slow_weights(group) + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) return loss def state_dict(self): @@ -52,37 +55,36 @@ class Lookahead(Optimizer): (id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items() } - fast_state = fast_state_dict["state"] - param_groups = fast_state_dict["param_groups"] + fast_state = fast_state_dict['state'] + param_groups = fast_state_dict['param_groups'] return { - "state": fast_state, - "slow_state": slow_state, - "param_groups": param_groups, + 'state': fast_state, + 'slow_state': slow_state, + 'param_groups': param_groups, } def load_state_dict(self, state_dict): + fast_state_dict = { + 'state': state_dict['state'], + 'param_groups': state_dict['param_groups'], + } + self.base_optimizer.load_state_dict(fast_state_dict) + + # We want to restore the slow state, but share param_groups reference + # with base_optimizer. This is a bit redundant but least code + slow_state_new = False if 'slow_state' not in state_dict: - print('Loading state_dict from optimizer without Lookahead applied') + print('Loading state_dict from optimizer without Lookahead applied.') state_dict['slow_state'] = defaultdict(dict) + slow_state_new = True slow_state_dict = { - "state": state_dict["slow_state"], - "param_groups": state_dict["param_groups"], - } - fast_state_dict = { - "state": state_dict["state"], - "param_groups": state_dict["param_groups"], + 'state': state_dict['slow_state'], + 'param_groups': state_dict['param_groups'], # this is pointless but saves code } super(Lookahead, self).load_state_dict(slow_state_dict) - self.base_optimizer.load_state_dict(fast_state_dict) - - def add_param_group(self, param_group): - r"""Add a param group to the :class:`Optimizer` s `param_groups`. - This can be useful when fine tuning a pre-trained network as frozen - layers can be made trainable and added to the :class:`Optimizer` as - training progresses. - Args: - param_group (dict): Specifies what Tensors should be optimized along - with group specific optimization options. - """ - param_group['step_counter'] = 0 - self.base_optimizer.add_param_group(param_group) + self.param_groups = self.base_optimizer.param_groups # make both ref same container + if slow_state_new: + # reapply defaults to catch missing lookahead specific ones + for name, default in self.defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py new file mode 100644 index 00000000..e69de29b