""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb This optimizer code was adapted from the following (starting with latest) * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py * https://github.com/cybertronai/pytorch-lamb Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason. Original copyrights for above sources are below. """ # Copyright (c) 2021, Habana Labs Ltd. All rights reserved. # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math

import torch
from torch.optim import Optimizer


class Lamb(Optimizer):
    """Implements a pure pytorch variant of FusedLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): whether to divide the running averages by
            ``1 - beta ** step`` before computing the update. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay coefficient. (default: 0.01)
        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm. The value from the
            constructor defaults is used for all parameter groups. ``None`` disables
            clipping. (default: 1.0)
        decoupled_decay (bool, optional): if True, weight decay is applied by scaling the parameter
            by ``1 - lr * weight_decay`` before the update; otherwise ``weight_decay * param`` is
            added to the update itself. (default: False)
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, decoupled_decay=False, use_nvlamb=False):
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, decoupled_decay=decoupled_decay,
            use_nvlamb=use_nvlamb)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Gradients are divided in place by the global clipping factor, so ``p.grad``
        holds the clipped gradient after this call (unless clipping is disabled with
        ``max_grad_norm=None``).

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        # Take the device from the first parameter found in any group, so an empty
        # leading group does not raise IndexError. Falls back to the default device.
        device = next(
            (p.device for group in self.param_groups for p in group['params']),
            None,
        )
        # Tensor-valued 1.0 used as the "no scaling" branch of the torch.where calls below.
        one_tensor = torch.tensor(1.0, device=device)

        loss = None
        if closure is not None:
            loss = closure()

        # Clipping threshold comes from the constructor defaults, not from each group.
        max_grad_norm = self.defaults['max_grad_norm']
        clip_enabled = max_grad_norm is not None

        # First pass: reject sparse gradients and accumulate the squared global grad norm.
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                if clip_enabled:
                    global_grad_norm.add_(grad.pow(2).sum())

        # Divisor applied to every gradient: norm / max_norm when the norm exceeds the
        # threshold, otherwise 1. None means clipping is disabled.
        clip_global_grad_norm = None
        if clip_enabled:
            global_grad_norm = torch.sqrt(global_grad_norm)
            clip_global_grad_norm = torch.where(
                global_grad_norm > max_grad_norm,
                global_grad_norm / max_grad_norm,
                one_tensor)

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            # Weight of the new gradient in the first-moment average.
            beta3 = 1 - beta1 if group['grad_averaging'] else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            group['step'] = group.get('step', 0) + 1

            step_size = group['lr']

            if group['bias_correction']:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            decoupled_decay = group['decoupled_decay']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if clip_global_grad_norm is not None:
                    # In-place: p.grad itself is rescaled.
                    grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if decoupled_decay and weight_decay != 0:
                    # Decoupled decay shrinks the weights directly, before the update.
                    p.data.mul_(1. - group['lr'] * weight_decay)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                if not decoupled_decay and weight_decay != 0:
                    # Non-decoupled decay is folded into the update, so it also affects the trust ratio.
                    update.add_(p.data, alpha=weight_decay)

                trust_ratio = one_tensor
                if weight_decay != 0 or group['use_nvlamb']:
                    # Layer adaptation. By default, skip layer adaptation on parameters that are
                    # excluded from weight norm, unless use_nvlamb == True, then always enabled.
                    w_norm = p.data.norm(2.0)
                    g_norm = update.norm(2.0)
                    # Fall back to a ratio of 1 when either norm is zero.
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                update.mul_(trust_ratio)
                p.data.add_(update, alpha=-step_size)

        return loss