From 1042b8a14630f78e0b405f389b8a29aff05d0113 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 9 Aug 2021 13:13:43 -0700
Subject: [PATCH] Add non fused LAMB optimizer option

---
 timm/optim/lamb.py          | 204 ++++++++++++++++++++++++++++++++++++
 timm/optim/optim_factory.py |   8 +-
 2 files changed, 211 insertions(+), 1 deletion(-)
 create mode 100644 timm/optim/lamb.py

diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
new file mode 100644
index 00000000..3c31d28f
--- /dev/null
+++ b/timm/optim/lamb.py
@@ -0,0 +1,204 @@
+""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
+
+This optimizer code was adapted from the following (starting with latest):
+* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
+* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+* https://github.com/cybertronai/pytorch-lamb
+
+Use FusedLamb if you can (it requires NVIDIA GPUs and APEX). This variant of Lamb is included to provide
+similar behaviour to APEX FusedLamb when you aren't using NVIDIA GPUs or cannot install APEX.
+
+Original copyrights for the above sources are below.
+"""
+# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
+
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+# MIT License
+#
+# Copyright (c) 2019 cybertronai
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+from torch.optim import Optimizer
+
+
+class NvLamb(Optimizer):
+    """Implements a pure PyTorch variant of the FusedLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
+    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+
+    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its norm. (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-6)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
+        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
+            calculating running averages of gradient. (default: True)
+        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()
+            method is called. (default: True)
+        max_grad_norm (float, optional): value used to clip the global grad norm
+            (default: 1.0)
+        use_nvlamb (bool, optional): apply the adaptive learning rate (trust ratio) to
+            parameters with 0.0 weight decay as well (default: False)
+
+    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, bias_correction=True,
+                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
+                 grad_averaging=True, set_grad_none=True,
+                 max_grad_norm=1.0, use_nvlamb=False):
+        defaults = dict(lr=lr, bias_correction=bias_correction,
+                        betas=betas, eps=eps, weight_decay=weight_decay,
+                        grad_averaging=grad_averaging,
+                        max_grad_norm=max_grad_norm)
+        super().__init__(params, defaults)
+        self.set_grad_none = set_grad_none
+        self.use_nvlamb = use_nvlamb
+
+    def zero_grad(self):
+        if self.set_grad_none:
+            for group in self.param_groups:
+                for p in group['params']:
+                    p.grad = None
+        else:
+            super(NvLamb, self).zero_grad()
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + device = self.param_groups[0]["params"][0].device + + loss = None + if closure is not None: + loss = closure() + + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm_ = torch.sqrt(global_grad_norm) + max_grad_norm = self.defaults['max_grad_norm'] + + if global_grad_norm_ > max_grad_norm: + clip_global_grad_norm = global_grad_norm_ / max_grad_norm + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + if grad_averaging: + beta3 = 1 - beta1 + else: + beta3 = 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + step_size = group['lr'] + + if bias_correction: + bias_correction1 = 1 - beta1 ** group['step'] + bias_correction2 = 1 - beta2 ** group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg_, exp_avg_sq_ = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + # m_t + exp_avg_.mul_(beta1).add_(grad, alpha=beta3) + # v_t + exp_avg_sq_.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + # create clones to avoid modifying runner stats + exp_avg = exp_avg_.div(bias_correction1) + exp_avg_sq = exp_avg_sq_.div(bias_correction2) + + # || w_t || + weight_norm = p.data.norm(2.0) + # u_t + exp_avg_sq_sqrt = torch.sqrt(exp_avg_sq) + adam_step = exp_avg.div_(exp_avg_sq_sqrt.add_(group['eps'])) + if group['weight_decay'] != 0: + adam_step.add_(p.data, alpha=group['weight_decay']) + # || u_t || + adam_norm = adam_step.norm(2.0) + if (group['weight_decay'] != 0 or self.use_nvlamb) and adam_norm > 0 and weight_norm > 0: + trust_ratio = weight_norm / adam_norm + trust_ratio = trust_ratio.item() + else: + trust_ratio = 1 + + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 2017d21f..b3035514 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -7,9 +7,11 @@ import torch import torch.nn as nn import torch.optim as optim +from .adabelief import AdaBelief from .adafactor import Adafactor from .adahessian import Adahessian from .adamp import AdamP +from .lamb import NvLamb from .lookahead import Lookahead from .nadam import Nadam from .novograd import NovoGrad @@ -17,7 +19,6 @@ from .nvnovograd import NvNovoGrad from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .adabelief import AdaBelief try: from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, 
@@ -148,6 +149,10 @@ def create_optimizer_v2(
         optimizer = NovoGrad(parameters, **opt_args)
     elif opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
+    elif opt_lower == 'lamb':
+        optimizer = NvLamb(parameters, **opt_args)
+
+    # NVIDIA fused optimizers, require APEX to be installed
     elif opt_lower == 'fusedsgd':
         opt_args.pop('eps', None)
         optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
@@ -163,6 +168,7 @@
     elif opt_lower == 'fusednovograd':
         opt_args.setdefault('betas', (0.95, 0.98))
         optimizer = FusedNovoGrad(parameters, **opt_args)
+
     else:
         assert False and "Invalid optimizer"
         raise ValueError
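
Below is a minimal usage sketch of the non-fused LAMB optimizer added by this patch; it is not part of the diff. The direct NvLamb construction uses only arguments defined in lamb.py above. The commented-out factory call assumes create_optimizer_v2 accepts the optimizer name via an opt keyword plus lr/weight_decay keyword arguments, as suggested by the optim_factory.py changes; the toy model and data are purely illustrative.

import torch
import torch.nn as nn

from timm.optim.lamb import NvLamb
# from timm.optim import create_optimizer_v2  # factory route, exact signature assumed

# Toy model and data, for illustration only.
model = nn.Linear(10, 2)
x, y = torch.randn(8, 10), torch.randn(8, 2)

# Direct construction with the defaults introduced in this patch.
optimizer = NvLamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                   weight_decay=0.01, max_grad_norm=1.0)

# Alternatively, via the factory hook added in optim_factory.py (signature assumed):
# optimizer = create_optimizer_v2(model, opt='lamb', lr=1e-3, weight_decay=0.01)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # Adam-style update scaled per parameter by the trust ratio
optimizer.zero_grad()  # sets grads to None since set_grad_none defaults to True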