From a426511c95e131389237e4ed2696f5967bc66130 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Aug 2021 17:21:56 -0700 Subject: [PATCH] More optimizer cleanup. Change all to no longer use .data. Improve (b)float16 use with adabelief. Add XLA compatible Lars. --- tests/test_optim.py | 27 +++++++ timm/optim/adabelief.py | 54 +++++++------- timm/optim/adafactor.py | 30 ++++---- timm/optim/adamp.py | 19 ++--- timm/optim/adamw.py | 14 ++-- timm/optim/lamb.py | 44 ++++++------ timm/optim/lars.py | 136 ++++++++++++++++++++++++++++++++++++ timm/optim/lookahead.py | 10 +-- timm/optim/madgrad.py | 36 +++++----- timm/optim/nadam.py | 22 +++--- timm/optim/nvnovograd.py | 12 ++-- timm/optim/optim_factory.py | 9 +++ timm/optim/radam.py | 24 ++++--- timm/optim/rmsprop_tf.py | 24 ++++--- timm/optim/sgdp.py | 12 ++-- 15 files changed, 332 insertions(+), 141 deletions(-) create mode 100644 timm/optim/lars.py diff --git a/tests/test_optim.py b/tests/test_optim.py index eacc8e29..e1b78482 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -490,6 +490,33 @@ def test_lamb(optimizer): _test_model(optimizer, dict(lr=1e-3)) +@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc']) +def test_lars(optimizer): + _test_basic_cases( + lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict(weight, bias, lr=1e-3), + optimizer, + lr=1e-1) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict_single(weight, bias, lr=1e-3), + optimizer, + lr=1e-3) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict_single(weight, bias, lr=1e-3), optimizer) + ) + _test_rosenbrock( + lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) + ) + _test_model(optimizer, dict(lr=1e-3)) + + @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw']) def test_madgrad(optimizer): _test_basic_cases( diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py index fc96e935..951d715c 100644 --- a/timm/optim/adabelief.py +++ b/timm/optim/adabelief.py @@ -68,6 +68,7 @@ class AdaBelief(Optimizer): for group in self.param_groups: group.setdefault('amsgrad', False) + @torch.no_grad() def reset(self): for group in self.param_groups: for p in group['params']: @@ -77,14 +78,15 @@ class AdaBelief(Optimizer): # State initialization state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_var'] = torch.zeros_like(p.data) + state['exp_avg_var'] = torch.zeros_like(p) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_var'] = torch.zeros_like(p.data) + state['max_exp_avg_var'] = torch.zeros_like(p) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. 
Arguments: @@ -93,50 +95,47 @@ class AdaBelief(Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - - # cast data type - half_precision = False - if p.data.dtype == torch.float16: - half_precision = True - p.data = p.data.float() - p.grad = p.grad.float() - - grad = p.grad.data + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() if grad.is_sparse: raise RuntimeError( 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') - amsgrad = group['amsgrad'] - state = self.state[p] + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + amsgrad = group['amsgrad'] beta1, beta2 = group['betas'] - + state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p_fp32) # Exponential moving average of squared gradient values - state['exp_avg_var'] = torch.zeros_like(p.data) + state['exp_avg_var'] = torch.zeros_like(p_fp32) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_var'] = torch.zeros_like(p.data) + state['max_exp_avg_var'] = torch.zeros_like(p_fp32) # perform weight decay, check if decoupled weight decay if group['decoupled_decay']: if not group['fixed_decay']: - p.data.mul_(1.0 - group['lr'] * group['weight_decay']) + p_fp32.mul_(1.0 - group['lr'] * group['weight_decay']) else: - p.data.mul_(1.0 - group['weight_decay']) + p_fp32.mul_(1.0 - group['weight_decay']) else: if group['weight_decay'] != 0: - grad.add_(p.data, alpha=group['weight_decay']) + grad.add_(p_fp32, alpha=group['weight_decay']) # get current state variable exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] @@ -164,7 +163,7 @@ class AdaBelief(Optimizer): if not group['rectify']: # Default update step_size = group['lr'] / bias_correction1 - p.data.addcdiv_(exp_avg, denom, value=-step_size) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) else: # Rectified update, forked from RAdam buffered = group['buffer'][int(state['step'] % 10)] @@ -192,12 +191,11 @@ class AdaBelief(Optimizer): if num_sma >= 5: denom = exp_avg_var.sqrt().add_(group['eps']) - p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) elif step_size > 0: - p.data.add_(exp_avg, alpha=-step_size * group['lr']) + p_fp32.add_(exp_avg, alpha=-step_size * group['lr']) - if half_precision: - p.data = p.data.half() - p.grad = p.grad.half() + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) return loss diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py index 5cf734b4..06057433 100644 --- a/timm/optim/adafactor.py +++ b/timm/optim/adafactor.py @@ -76,6 +76,7 @@ class Adafactor(torch.optim.Optimizer): c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. 
Arguments: @@ -83,22 +84,22 @@ class Adafactor(torch.optim.Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: raise RuntimeError('Adafactor does not support sparse gradients.') state = self.state[p] - grad_shape = grad.shape - factored, use_first_moment = self._get_options(group, grad_shape) + factored, use_first_moment = self._get_options(group, grad.shape) # State Initialization if len(state) == 0: state['step'] = 0 @@ -107,8 +108,8 @@ class Adafactor(torch.optim.Optimizer): # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(grad) if factored: - state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) - state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) else: state['exp_avg_sq'] = torch.zeros_like(grad) @@ -122,12 +123,12 @@ class Adafactor(torch.optim.Optimizer): else: state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() state['step'] += 1 - state['RMS'] = self._rms(p_data_fp32) + state['RMS'] = self._rms(p_fp32) lr_t = self._get_lr(group, state) beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) @@ -157,11 +158,10 @@ class Adafactor(torch.optim.Optimizer): update = exp_avg if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t) - p_data_fp32.add_(-update) - - if p.data.dtype in {torch.float16, torch.bfloat16}: - p.data.copy_(p_data_fp32) + p_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) return loss diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py index dfe8c7dc..ee187633 100644 --- a/timm/optim/adamp.py +++ b/timm/optim/adamp.py @@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): wd = 1. 
expand_size = (-1,) + (1,) * (len(p.shape) - 1) for view_func in [_channel_view, _layer_view]: - param_view = view_func(p.data) + param_view = view_func(p) grad_view = view_func(grad) cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() + # FIXME this is a problem for PyTorch XLA if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): - p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) + p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) wd = wd_ratio return perturb, wd @@ -47,17 +48,19 @@ class AdamP(Optimizer): delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) super(AdamP, self).__init__(params, defaults) + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad beta1, beta2 = group['betas'] nesterov = group['nesterov'] @@ -66,8 +69,8 @@ class AdamP(Optimizer): # State initialization if len(state) == 0: state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p.data) - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) # Adam exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] @@ -94,9 +97,9 @@ class AdamP(Optimizer): # Weight decay if group['weight_decay'] > 0: - p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) + p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) # Step - p.data.add_(perturb, alpha=-step_size) + p.add_(perturb, alpha=-step_size) return loss diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py index 4d2bf078..66478bc6 100644 --- a/timm/optim/adamw.py +++ b/timm/optim/adamw.py @@ -55,6 +55,7 @@ class AdamW(Optimizer): for group in self.param_groups: group.setdefault('amsgrad', False) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -64,7 +65,8 @@ class AdamW(Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: @@ -75,7 +77,7 @@ class AdamW(Optimizer): p.data.mul_(1 - group['lr'] * group['weight_decay']) # Perform optimization step - grad = p.grad.data + grad = p.grad if grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') amsgrad = group['amsgrad'] @@ -86,12 +88,12 @@ class AdamW(Optimizer): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. 
values - state['max_exp_avg_sq'] = torch.zeros_like(p.data) + state['max_exp_avg_sq'] = torch.zeros_like(p) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] if amsgrad: @@ -115,6 +117,6 @@ class AdamW(Optimizer): step_size = group['lr'] / bias_correction1 - p.data.addcdiv_(exp_avg, denom, value=-step_size) + p.addcdiv_(exp_avg, denom, value=-step_size) return loss diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py index 5308e348..a65b4e0f 100644 --- a/timm/optim/lamb.py +++ b/timm/optim/lamb.py @@ -5,10 +5,14 @@ This optimizer code was adapted from the following (starting with latest) * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py * https://github.com/cybertronai/pytorch-lamb -Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is -similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason. +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman """ # Copyright (c) 2021, Habana Labs Ltd. All rights reserved. @@ -60,8 +64,7 @@ class Lamb(Optimizer): LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999)) @@ -72,8 +75,7 @@ class Lamb(Optimizer): calculating running averages of gradient. (default: True) set_grad_none (bool, optional): whether set grad to None when zero_grad() method is called. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm - (default: 1.0) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 weight decay parameter (default: False) @@ -91,25 +93,26 @@ class Lamb(Optimizer): grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb) super().__init__(params, defaults) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" - device = self.param_groups[0]["params"][0].device - one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly - loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() + device = self.param_groups[0]["params"][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly global_grad_norm = torch.zeros(1, device=device) for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad if grad.is_sparse: raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') global_grad_norm.add_(grad.pow(2).sum()) @@ -145,15 +148,15 @@ class Lamb(Optimizer): for p in group['params']: if p.grad is None: continue - grad = p.grad.data.div_(clip_global_grad_norm) + grad = p.grad.div_(clip_global_grad_norm) state = self.state[p] # State initialization if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] @@ -166,20 +169,21 @@ class Lamb(Optimizer): weight_decay = group['weight_decay'] if weight_decay != 0: - update.add_(p.data, alpha=weight_decay) + update.add_(p, alpha=weight_decay) - trust_ratio = one_tensor if weight_decay != 0 or group['use_nvlamb']: # Layer adaptation. By default, skip layer adaptation on parameters that are # excluded from weight decay, unless use_nvlamb == True, then always enabled. - w_norm = p.data.norm(2.0) + w_norm = p.norm(2.0) g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA trust_ratio = torch.where( w_norm > 0, torch.where(g_norm > 0, w_norm / g_norm, one_tensor), one_tensor, ) - update.mul_(trust_ratio) - p.data.add_(update, alpha=-group['lr']) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) return loss diff --git a/timm/optim/lars.py b/timm/optim/lars.py new file mode 100644 index 00000000..7c530e1a --- /dev/null +++ b/timm/optim/lars.py @@ -0,0 +1,136 @@ +""" PyTorch LARS / LARC Optimizer + +An implementation of LARS (SGD) + LARC in PyTorch + +Based on: + * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Additional cleanup and modifications to properly support PyTorch XLA. + +Copyright 2021 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer, required + + +class Lars(Optimizer): + """ LARS for PyTorch + + Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. 
(default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) + eps (float): eps for division denominator (default: 1e-8) + larc (bool): enable LARC clipping (default: False) + always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False) + """ + + def __init__( + self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coeff=0.001, + eps=1e-8, + larc=False, + always_scale=False, + ): + if lr is not required and lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coeff=trust_coeff, + eps=eps, + larc=larc, + always_scale=always_scale, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]["params"][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + + # exclude scaling for params with 0 weight decay + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + trust_coeff = group['trust_coeff'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + # apply LARS scaling, LARC clipping, weight decay + # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + if weight_decay != 0 or group['always_scale']: + w_norm = p.norm(2.0) + g_norm = grad.norm(2.0) + trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, trust_ratio, one_tensor), + one_tensor, + ) + if group['larc']: + trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + grad.add_(p, alpha=weight_decay) + grad.mul_(trust_ratio) + + # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1. 
- dampening) + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + p.add_(grad, alpha=-group['lr']) + + return loss \ No newline at end of file diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py index d0ab2ed2..462c3acd 100644 --- a/timm/optim/lookahead.py +++ b/timm/optim/lookahead.py @@ -27,22 +27,24 @@ class Lookahead(Optimizer): for group in self._base_optimizer.param_groups: group.setdefault(name, default) + @torch.no_grad() def update_slow(self, group): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self._base_optimizer.state[fast_p] if 'lookahead_slow_buff' not in param_state: - param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data) - param_state['lookahead_slow_buff'].copy_(fast_p.data) + param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) + param_state['lookahead_slow_buff'].copy_(fast_p) slow = param_state['lookahead_slow_buff'] - slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha']) - fast_p.data.copy_(slow) + slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) + fast_p.copy_(slow) def sync_lookahead(self): for group in self._base_optimizer.param_groups: self.update_slow(group) + @torch.no_grad() def step(self, closure=None): loss = self._base_optimizer.step(closure) for group in self._base_optimizer.param_groups: diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py index 7f8d73e8..4d88b753 100644 --- a/timm/optim/madgrad.py +++ b/timm/optim/madgrad.py @@ -82,6 +82,7 @@ class MADGRAD(torch.optim.Optimizer): def supports_flat_params(self) -> bool: return True + @torch.no_grad() def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. @@ -91,13 +92,10 @@ class MADGRAD(torch.optim.Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() - # step counter must be stored in state to ensure correct behavior under - # optimizer sharding - if 'k' not in self.state: - self.state['k'] = torch.tensor([0], dtype=torch.long) - k = self.state['k'].item() + step = self.state.setdefault('step', 0) # k for group in self.param_groups: eps = group["eps"] @@ -106,19 +104,19 @@ class MADGRAD(torch.optim.Optimizer): momentum = group["momentum"] ck = 1 - momentum - lamb = lr * math.pow(k + 1, 0.5) + lamb = lr * math.sqrt(step + 1) for p in group["params"]: if p.grad is None: continue - grad = p.grad.data + grad = p.grad state = self.state[p] if "grad_sum_sq" not in state: - state["grad_sum_sq"] = torch.zeros_like(p.data).detach() - state["s"] = torch.zeros_like(p.data).detach() + state["grad_sum_sq"] = torch.zeros_like(p) + state["s"] = torch.zeros_like(p) if momentum != 0: - state["x0"] = torch.clone(p.data).detach() + state["x0"] = torch.clone(p).detach() if momentum != 0.0 and grad.is_sparse: raise RuntimeError("momentum != 0 is not compatible with sparse gradients") @@ -129,11 +127,11 @@ class MADGRAD(torch.optim.Optimizer): # Apply weight decay if weight_decay != 0: if group['decoupled_decay']: - p.data.mul_(1.0 - group['lr'] * weight_decay) + p.mul_(1.0 - group['lr'] * weight_decay) else: if grad.is_sparse: raise RuntimeError("weight_decay option is not compatible with sparse gradients") - grad.add_(p.data, alpha=weight_decay) + grad.add_(p, alpha=weight_decay) if grad.is_sparse: grad = grad.coalesce() @@ -161,12 +159,12 @@ class MADGRAD(torch.optim.Optimizer): p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) # Copy updated 
masked p to dense p using an add operation p_masked._values().add_(p_kp1_masked_vals, alpha=-1) - p.data.add_(p_masked, alpha=-1) + p.add_(p_masked, alpha=-1) else: if momentum == 0: # Compute x_0 from other known quantities rms = grad_sum_sq.pow(1 / 3).add_(eps) - x0 = p.data.addcdiv(s, rms, value=1) + x0 = p.addcdiv(s, rms, value=1) else: x0 = state["x0"] @@ -175,16 +173,16 @@ class MADGRAD(torch.optim.Optimizer): rms = grad_sum_sq.pow(1 / 3).add_(eps) # Update s - s.data.add_(grad, alpha=lamb) + s.add_(grad, alpha=lamb) # Step if momentum == 0: - p.data.copy_(x0.addcdiv(s, rms, value=-1)) + p.copy_(x0.addcdiv(s, rms, value=-1)) else: z = x0.addcdiv(s, rms, value=-1) # p is a moving average of z - p.data.mul_(1 - ck).add_(z, alpha=ck) + p.mul_(1 - ck).add_(z, alpha=ck) - self.state['k'] += 1 + self.state['step'] += 1 return loss diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py index 4382d38c..6268d5d4 100644 --- a/timm/optim/nadam.py +++ b/timm/optim/nadam.py @@ -1,3 +1,5 @@ +import math + import torch from torch.optim.optimizer import Optimizer @@ -33,6 +35,7 @@ class Nadam(Optimizer): lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay) super(Nadam, self).__init__(params, defaults) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -42,21 +45,22 @@ class Nadam(Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 state['m_schedule'] = 1. - state['exp_avg'] = torch.zeros_like(p.data) - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) # Warming momentum schedule m_schedule = state['m_schedule'] @@ -66,9 +70,10 @@ class Nadam(Optimizer): eps = group['eps'] state['step'] += 1 t = state['step'] + bias_correction2 = 1 - beta2 ** t if group['weight_decay'] != 0: - grad = grad.add(p.data, alpha=group['weight_decay']) + grad = grad.add(p, alpha=group['weight_decay']) momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay))) momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) @@ -79,10 +84,9 @@ class Nadam(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2) - exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) - denom = exp_avg_sq_prime.sqrt_().add_(eps) - p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new)) - p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next)) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new)) + p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next)) return loss diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py index d763dfb0..fda3f4a6 100644 --- a/timm/optim/nvnovograd.py +++ b/timm/optim/nvnovograd.py @@ -51,6 +51,7 @@ class NvNovoGrad(Optimizer): for group in self.param_groups: group.setdefault('amsgrad', False) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. 
@@ -60,13 +61,14 @@ class NvNovoGrad(Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad if grad.is_sparse: raise RuntimeError('Sparse gradients are not supported.') amsgrad = group['amsgrad'] @@ -77,7 +79,7 @@ class NvNovoGrad(Optimizer): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) if amsgrad: @@ -108,11 +110,11 @@ class NvNovoGrad(Optimizer): grad.div_(denom) if group['weight_decay'] != 0: - grad.add_(p.data, alpha=group['weight_decay']) + grad.add_(p, alpha=group['weight_decay']) if group['grad_averaging']: grad.mul_(1 - beta1) exp_avg.mul_(beta1).add_(grad) - p.data.add_(exp_avg, alpha=-group['lr']) + p.add_(exp_avg, alpha=-group['lr']) return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 2157f73d..bc58014e 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -12,6 +12,7 @@ from .adafactor import Adafactor from .adahessian import Adahessian from .adamp import AdamP from .lamb import Lamb +from .lars import Lars from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam @@ -163,6 +164,14 @@ def create_optimizer_v2( optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'lamb': optimizer = Lamb(parameters, **opt_args) + elif opt_lower == 'larc': + optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args) + elif opt_lower == 'lars': + optimizer = Lars(parameters, momentum=momentum, **opt_args) + elif opt_lower == 'nlarc': + optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args) + elif opt_lower == 'nlars': + optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'madgrad': optimizer = MADGRAD(parameters, momentum=momentum, **opt_args) elif opt_lower == 'madgradw': diff --git a/timm/optim/radam.py b/timm/optim/radam.py index 8e399c98..eb8d22e0 100644 --- a/timm/optim/radam.py +++ b/timm/optim/radam.py @@ -18,31 +18,33 @@ class RAdam(Optimizer): def __setstate__(self, state): super(RAdam, self).__setstate__(state) + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data.float() + grad = p.grad.float() if grad.is_sparse: raise RuntimeError('RAdam does not support sparse gradients') - p_data_fp32 = p.data.float() + p_fp32 = p.float() state = self.state[p] if len(state) == 0: state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + state['exp_avg'] = torch.zeros_like(p_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_fp32) else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + state['exp_avg'] = state['exp_avg'].type_as(p_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] @@ -73,15 +75,15 @@ class RAdam(Optimizer): buffered[2] = step_size if 
group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) # more conservative since it's an approximated value if num_sma >= 5: denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) else: - p_data_fp32.add_(exp_avg, alpha=-step_size) + p_fp32.add_(exp_avg, alpha=-step_size) - p.data.copy_(p_data_fp32) + p.copy_(p_fp32) return loss diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py index e6139b33..0817887d 100644 --- a/timm/optim/rmsprop_tf.py +++ b/timm/optim/rmsprop_tf.py @@ -4,7 +4,7 @@ Originally cut & paste from PyTorch RMSProp https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE -Modifications Copyright 2020 Ross Wightman +Modifications Copyright 2021 Ross Wightman """ import torch @@ -69,6 +69,7 @@ class RMSpropTF(Optimizer): group.setdefault('momentum', 0) group.setdefault('centered', False) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -78,13 +79,14 @@ class RMSpropTF(Optimizer): """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad if grad.is_sparse: raise RuntimeError('RMSprop does not support sparse gradients') state = self.state[p] @@ -92,11 +94,11 @@ class RMSpropTF(Optimizer): # State initialization if len(state) == 0: state['step'] = 0 - state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero if group['momentum'] > 0: - state['momentum_buffer'] = torch.zeros_like(p.data) + state['momentum_buffer'] = torch.zeros_like(p) if group['centered']: - state['grad_avg'] = torch.zeros_like(p.data) + state['grad_avg'] = torch.zeros_like(p) square_avg = state['square_avg'] one_minus_alpha = 1. - group['alpha'] @@ -105,9 +107,9 @@ class RMSpropTF(Optimizer): if group['weight_decay'] != 0: if group['decoupled_decay']: - p.data.mul_(1. - group['lr'] * group['weight_decay']) + p.mul_(1. 
- group['lr'] * group['weight_decay']) else: - grad = grad.add(p.data, alpha=group['weight_decay']) + grad = grad.add(p, alpha=group['weight_decay']) # Tensorflow order of ops for updating squared avg square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) @@ -126,12 +128,12 @@ class RMSpropTF(Optimizer): # Tensorflow accumulates the LR scaling in the momentum buffer if group['lr_in_momentum']: buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) - p.data.add_(-buf) + p.add_(-buf) else: # PyTorch scales the param update by LR buf.mul_(group['momentum']).addcdiv_(grad, avg) - p.data.add_(buf, alpha=-group['lr']) + p.add_(buf, alpha=-group['lr']) else: - p.data.addcdiv_(grad, avg, value=-group['lr']) + p.addcdiv_(grad, avg, value=-group['lr']) return loss diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py index 6d17739e..baf05fa5 100644 --- a/timm/optim/sgdp.py +++ b/timm/optim/sgdp.py @@ -24,10 +24,12 @@ class SGDP(Optimizer): nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) super(SGDP, self).__init__(params, defaults) + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: weight_decay = group['weight_decay'] @@ -38,12 +40,12 @@ class SGDP(Optimizer): for p in group['params']: if p.grad is None: continue - grad = p.grad.data + grad = p.grad state = self.state[p] # State initialization if len(state) == 0: - state['momentum'] = torch.zeros_like(p.data) + state['momentum'] = torch.zeros_like(p) # SGD buf = state['momentum'] @@ -60,9 +62,9 @@ class SGDP(Optimizer): # Weight decay if weight_decay != 0: - p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) # Step - p.data.add_(d_p, alpha=-group['lr']) + p.add_(d_p, alpha=-group['lr']) return loss
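
Usage sketch (illustrative, not part of the patch): with this change applied, the new optimizer can be constructed directly via the Lars class added in timm/optim/lars.py, or selected by name ('lars', 'larc', 'nlars', 'nlarc') through create_optimizer_v2, mirroring the new tests. The toy model, hyperparameter values, and training-step boilerplate below are assumptions for demonstration only.

    import torch
    import torch.nn as nn

    from timm.optim.lars import Lars
    from timm.optim.optim_factory import create_optimizer_v2

    model = nn.Linear(10, 2)

    # Direct construction; larc=True enables LARC-style clipping of the trust ratio.
    opt = Lars(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, larc=True)

    # Equivalent selection by name through the factory, as exercised in tests/test_optim.py.
    opt = create_optimizer_v2(model.parameters(), 'larc', lr=0.1, momentum=0.9, weight_decay=1e-4)

    # One illustrative optimization step.
    x, y = torch.randn(8, 10), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()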