""" PyTorch MADGRAD optimizer
|
|
|
|
MADGRAD: https://arxiv.org/abs/2101.11075
|
|
|
|
Code from: https://github.com/facebookresearch/madgrad
|
|
"""
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
|
|
import torch
|
|
import torch.optim
|
|
|
|
if TYPE_CHECKING:
|
|
from torch.optim.optimizer import _params_t
|
|
else:
|
|
_params_t = Any
|
|
|
|
|
|

class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better. Currently GPU-only.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-6).
        decoupled_decay (bool):
            If True, apply weight decay directly to the parameters (AdamW-style decoupled decay)
            instead of adding it to the gradient (default: False).
    """

    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
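            # ck weights each new iterate z in the moving average below (p <- (1 - ck) * p + ck * z);
            # with momentum = 0 this reduces to plain dual averaging (p is set directly to z).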
            ck = 1 - momentum

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
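                # lamb is the dual averaging weight lambda_k = lr * sqrt(k) from the MADGRAD paper;
                # it scales how strongly the k-th gradient contributes to the running sums below.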
                lamb = lr * math.sqrt(state['step'])

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        grad.add_(p, alpha=weight_decay)

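                # Sparse path: the running statistics are updated only on the coordinates present
                # in the sparse gradient, via masked (sparse) views of p, grad_sum_sq and s.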
                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
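                    # Dense path, following the MADGRAD recursion:
                    #   grad_sum_sq accumulates lambda_k * g_k^2, s accumulates lambda_k * g_k,
                    #   z_{k+1} = x_0 - s / (grad_sum_sq^(1/3) + eps),
                    #   p_{k+1} = (1 - ck) * p_k + ck * z_{k+1}.
                    # When momentum == 0, x_0 is not stored; since the previous step set
                    # p = x_0 - s / rms with the same s and rms, it is recovered as p + s / rms.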
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
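

if __name__ == '__main__':
    # A minimal usage sketch. The toy model, data, and hyper-parameter values below are
    # illustrative only; per the class docstring, in practice the model and tensors would
    # live on a GPU and the learning rate would come from a hyper-parameter sweep.
    import torch.nn as nn

    model = nn.Linear(10, 1)
    optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

    data, target = torch.randn(8, 10), torch.randn(8, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()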