From 709d5e0d9d2d3f501531506eda96a435737223a3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Feb 2023 23:55:05 -0800 Subject: [PATCH] Add Lion optimizer --- timm/optim/lion.py | 87 +++++++++++++++++++++++++++++++++++++ timm/optim/optim_factory.py | 3 ++ 2 files changed, 90 insertions(+) create mode 100644 timm/optim/lion.py diff --git a/timm/optim/lion.py b/timm/optim/lion.py new file mode 100644 index 00000000..434d9831 --- /dev/null +++ b/timm/optim/lion.py @@ -0,0 +1,87 @@ +""" Lion Optimizer +Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675 +Original Impl: https://github.com/google/automl/tree/master/lion +""" +# Copyright 2023 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import torch +from torch.optim.optimizer import Optimizer + + +class Lion(Optimizer): + r"""Implements Lion algorithm.""" + + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): + """Initialize the hyperparameters. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + beta1, beta2 = group['betas'] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + p.add_(torch.sign(update), alpha=-group['lr']) + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 8613a62c..c2f253d3 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -18,6 +18,7 @@ from .adamp import AdamP from .adan import Adan from .lamb import Lamb from .lars import Lars +from .lion import Lion from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam @@ -313,6 +314,8 @@ def create_optimizer_v2( optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) elif opt_lower == 'rmsproptf': optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == 'lion': + optimizer = Lion(parameters, **opt_args) # second order elif opt_lower == 'adahessian':