From f35d6ea57b1d1ca7106cc7df8b6fd3e6bab0b187 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 16 Feb 2023 15:48:00 -0800 Subject: [PATCH] Add multi-tensor (foreach) version of Lion in style of upcoming PyTorch 2.0 optimizers --- timm/optim/lion.py | 169 ++++++++++++++++++++++++++++++++---- timm/optim/optim_factory.py | 23 ++++- 2 files changed, 175 insertions(+), 17 deletions(-) diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 434d9831..4d808642 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -16,6 +16,8 @@ Original Impl: https://github.com/google/automl/tree/master/lion # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from typing import List + import torch from torch.optim.optimizer import Optimizer @@ -23,7 +25,15 @@ from torch.optim.optimizer import Optimizer class Lion(Optimizer): r"""Implements Lion algorithm.""" - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0.0, + maximize=False, + foreach=None, + ): """Initialize the hyperparameters. Args: @@ -41,9 +51,21 @@ class Lion(Optimizer): raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) - defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + ) super().__init__(params, defaults) + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('foreach', None) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -61,27 +83,144 @@ class Lion(Optimizer): loss = closure() for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + beta1, beta2 = group['betas'] + for p in group['params']: if p.grad is None: continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('Lion does not support sparse gradients') + grads.append(p.grad) - # Perform stepweight decay - p.data.mul_(1 - group['lr'] * group['weight_decay']) - - grad = p.grad state = self.state[p] + # State initialization if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - exp_avg = state['exp_avg'] - beta1, beta2 = group['betas'] + exp_avgs.append(state['exp_avg']) - # Weight update - update = exp_avg * beta1 + grad * (1 - beta1) - p.add_(torch.sign(update), alpha=-group['lr']) - # Decay the momentum running average coefficient - exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + lion( + params_with_grad, + grads, + exp_avgs, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + foreach=group['foreach'], + ) return loss + + +def lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + maximize: bool = False, + foreach: bool = None, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, +): + r"""Functional API that performs Lion algorithm computation. + """ + if foreach is None: + # Placeholder for more complex foreach logic to be added when value is not set + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError('torch.jit.script not supported with foreach optimizers') + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_lion + else: + func = _single_tensor_lion + + func( + params, + grads, + exp_avgs, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + ) + + +def _single_tensor_lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + param = torch.view_as_real(param) + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Weight update + update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) + param.add_(torch.sign(update), alpha=-lr) + + # Decay the momentum running average coefficient + exp_avg.lerp_(grad, 1 - beta2) + + +def _multi_tensor_lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, +): + if len(params) == 0: + return + + if maximize: + grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] + + grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads] + exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs] + params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params] + + # Perform stepweight decay + torch._foreach_mul_(params, 1 - lr * weight_decay) + + # Weight update + updates = torch._foreach_mul(exp_avgs, beta1) + torch._foreach_add_(updates, grads, alpha=1 - beta1) + + updates = [u.sign() for u in updates] + torch._foreach_add_(params, updates, alpha=-lr) + + # Decay the momentum running average coefficient + torch._foreach_mul_(exp_avgs, beta2) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c2f253d3..10950210 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -36,6 +36,12 @@ except ImportError: _logger = logging.getLogger(__name__) +# optimizers to default to multi-tensor +_DEFAULT_FOREACH = { + 'lion', +} + + def param_groups_weight_decay( model: nn.Module, weight_decay=1e-5, @@ -162,7 +168,8 @@ def optimizer_kwargs(cfg): opt=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay, - momentum=cfg.momentum) + momentum=cfg.momentum, + ) if getattr(cfg, 'opt_eps', None) is not None: kwargs['eps'] = cfg.opt_eps if getattr(cfg, 'opt_betas', None) is not None: @@ -171,6 +178,8 @@ def optimizer_kwargs(cfg): kwargs['layer_decay'] = cfg.layer_decay if getattr(cfg, 'opt_args', None) is not None: kwargs.update(cfg.opt_args) + if getattr(cfg, 'opt_foreach', None) is not None: + kwargs['foreach'] = cfg.opt_foreach return kwargs @@ -191,6 +200,7 @@ def create_optimizer_v2( lr: Optional[float] = None, weight_decay: float = 0., momentum: float = 0.9, + foreach: Optional[bool] = None, filter_bias_and_bn: bool = True, layer_decay: Optional[float] = None, param_group_fn: Optional[Callable] = None, @@ -209,6 +219,7 @@ def create_optimizer_v2( lr: initial learning rate weight_decay: weight decay to apply in optimizer momentum: momentum for momentum based optimizers (others may use betas via kwargs) + foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay **kwargs: extra optimizer specific kwargs to pass through @@ -228,7 +239,8 @@ def create_optimizer_v2( model_or_params, weight_decay=weight_decay, layer_decay=layer_decay, - no_weight_decay_list=no_weight_decay) + no_weight_decay_list=no_weight_decay, + ) weight_decay = 0. elif weight_decay and filter_bias_and_bn: parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) @@ -246,9 +258,16 @@ def create_optimizer_v2( assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' opt_args = dict(weight_decay=weight_decay, **kwargs) + if lr is not None: opt_args.setdefault('lr', lr) + if foreach is None: + if opt in _DEFAULT_FOREACH: + opt_args.setdefault('foreach', True) + else: + opt_args['foreach'] = foreach + # basic SGD & related if opt_lower == 'sgd' or opt_lower == 'nesterov': # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons