Add multi-tensor (foreach) version of Lion in style of upcoming PyTorch 2.0 optimizers

pull/1680/head
Ross Wightman 2 years ago
parent 709d5e0d9d
commit f35d6ea57b

@ -16,6 +16,8 @@ Original Impl: https://github.com/google/automl/tree/master/lion
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
from typing import List
import torch import torch
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
@ -23,7 +25,15 @@ from torch.optim.optimizer import Optimizer
class Lion(Optimizer): class Lion(Optimizer):
r"""Implements Lion algorithm.""" r"""Implements Lion algorithm."""
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0.0,
maximize=False,
foreach=None,
):
"""Initialize the hyperparameters. """Initialize the hyperparameters.
Args: Args:
@ -41,9 +51,21 @@ class Lion(Optimizer):
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
if not 0.0 <= betas[1] < 1.0: if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
)
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('maximize', False)
group.setdefault('foreach', None)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -61,27 +83,144 @@ class Lion(Optimizer):
loss = closure() loss = closure()
for group in self.param_groups: for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
beta1, beta2 = group['betas']
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('Lion does not support sparse gradients')
grads.append(p.grad)
# Perform stepweight decay
p.data.mul_(1 - group['lr'] * group['weight_decay'])
grad = p.grad
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
# Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg'] = torch.zeros_like(p)
exp_avg = state['exp_avg'] exp_avgs.append(state['exp_avg'])
beta1, beta2 = group['betas']
lion(
params_with_grad,
grads,
exp_avgs,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
maximize=group['maximize'],
foreach=group['foreach'],
)
return loss
def lion(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
maximize: bool = False,
foreach: bool = None,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
):
r"""Functional API that performs Lion algorithm computation.
"""
if foreach is None:
# Placeholder for more complex foreach logic to be added when value is not set
foreach = False
if foreach and torch.jit.is_scripting():
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_lion
else:
func = _single_tensor_lion
func(
params,
grads,
exp_avgs,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
maximize=maximize,
)
def _single_tensor_lion(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
):
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
param = torch.view_as_real(param)
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
# Weight update # Weight update
update = exp_avg * beta1 + grad * (1 - beta1) update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
p.add_(torch.sign(update), alpha=-group['lr']) param.add_(torch.sign(update), alpha=-lr)
# Decay the momentum running average coefficient # Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) exp_avg.lerp_(grad, 1 - beta2)
return loss
def _multi_tensor_lion(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
):
if len(params) == 0:
return
if maximize:
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
# Perform stepweight decay
torch._foreach_mul_(params, 1 - lr * weight_decay)
# Weight update
updates = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(updates, grads, alpha=1 - beta1)
updates = [u.sign() for u in updates]
torch._foreach_add_(params, updates, alpha=-lr)
# Decay the momentum running average coefficient
torch._foreach_mul_(exp_avgs, beta2)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)

@ -36,6 +36,12 @@ except ImportError:
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# optimizers to default to multi-tensor
_DEFAULT_FOREACH = {
'lion',
}
def param_groups_weight_decay( def param_groups_weight_decay(
model: nn.Module, model: nn.Module,
weight_decay=1e-5, weight_decay=1e-5,
@ -162,7 +168,8 @@ def optimizer_kwargs(cfg):
opt=cfg.opt, opt=cfg.opt,
lr=cfg.lr, lr=cfg.lr,
weight_decay=cfg.weight_decay, weight_decay=cfg.weight_decay,
momentum=cfg.momentum) momentum=cfg.momentum,
)
if getattr(cfg, 'opt_eps', None) is not None: if getattr(cfg, 'opt_eps', None) is not None:
kwargs['eps'] = cfg.opt_eps kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None: if getattr(cfg, 'opt_betas', None) is not None:
@ -171,6 +178,8 @@ def optimizer_kwargs(cfg):
kwargs['layer_decay'] = cfg.layer_decay kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None: if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args) kwargs.update(cfg.opt_args)
if getattr(cfg, 'opt_foreach', None) is not None:
kwargs['foreach'] = cfg.opt_foreach
return kwargs return kwargs
@ -191,6 +200,7 @@ def create_optimizer_v2(
lr: Optional[float] = None, lr: Optional[float] = None,
weight_decay: float = 0., weight_decay: float = 0.,
momentum: float = 0.9, momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True, filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None, layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable] = None, param_group_fn: Optional[Callable] = None,
@ -209,6 +219,7 @@ def create_optimizer_v2(
lr: initial learning rate lr: initial learning rate
weight_decay: weight decay to apply in optimizer weight_decay: weight decay to apply in optimizer
momentum: momentum for momentum based optimizers (others may use betas via kwargs) momentum: momentum for momentum based optimizers (others may use betas via kwargs)
foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
**kwargs: extra optimizer specific kwargs to pass through **kwargs: extra optimizer specific kwargs to pass through
@ -228,7 +239,8 @@ def create_optimizer_v2(
model_or_params, model_or_params,
weight_decay=weight_decay, weight_decay=weight_decay,
layer_decay=layer_decay, layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay) no_weight_decay_list=no_weight_decay,
)
weight_decay = 0. weight_decay = 0.
elif weight_decay and filter_bias_and_bn: elif weight_decay and filter_bias_and_bn:
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
@ -246,9 +258,16 @@ def create_optimizer_v2(
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
opt_args = dict(weight_decay=weight_decay, **kwargs) opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None: if lr is not None:
opt_args.setdefault('lr', lr) opt_args.setdefault('lr', lr)
if foreach is None:
if opt in _DEFAULT_FOREACH:
opt_args.setdefault('foreach', True)
else:
opt_args['foreach'] = foreach
# basic SGD & related # basic SGD & related
if opt_lower == 'sgd' or opt_lower == 'nesterov': if opt_lower == 'sgd' or opt_lower == 'nesterov':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons

Loading…
Cancel
Save