Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support.
parent
fcb6258877
commit
80078c47bb
@ -1,10 +1,13 @@
|
|||||||
from .nadam import Nadam
|
from .adamp import AdamP
|
||||||
from .rmsprop_tf import RMSpropTF
|
|
||||||
from .adamw import AdamW
|
from .adamw import AdamW
|
||||||
from .radam import RAdam
|
from .adafactor import Adafactor
|
||||||
|
from .adahessian import Adahessian
|
||||||
|
from .lookahead import Lookahead
|
||||||
|
from .nadam import Nadam
|
||||||
from .novograd import NovoGrad
|
from .novograd import NovoGrad
|
||||||
from .nvnovograd import NvNovoGrad
|
from .nvnovograd import NvNovoGrad
|
||||||
from .lookahead import Lookahead
|
from .radam import RAdam
|
||||||
from .adamp import AdamP
|
from .rmsprop_tf import RMSpropTF
|
||||||
from .sgdp import SGDP
|
from .sgdp import SGDP
|
||||||
|
|
||||||
from .optim_factory import create_optimizer
|
from .optim_factory import create_optimizer
|
@ -0,0 +1,174 @@
|
|||||||
|
""" Adafactor Optimizer
|
||||||
|
|
||||||
|
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
||||||
|
|
||||||
|
Original header/copyright below.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class Adafactor(torch.optim.Optimizer):
|
||||||
|
"""Implements Adafactor algorithm.
|
||||||
|
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
|
||||||
|
(see https://arxiv.org/abs/1804.04235)
|
||||||
|
|
||||||
|
Note that this optimizer internally adjusts the learning rate depending on the
|
||||||
|
*scale_parameter*, *relative_step* and *warmup_init* options.
|
||||||
|
|
||||||
|
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
||||||
|
`relative_step=False`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||||
|
lr (float, optional): external learning rate (default: None)
|
||||||
|
eps (tuple[float, float]): regularization constants for square gradient
|
||||||
|
and parameter scale respectively (default: (1e-30, 1e-3))
|
||||||
|
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
|
||||||
|
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
|
||||||
|
beta1 (float): coefficient used for computing running averages of gradient (default: None)
|
||||||
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||||
|
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
|
||||||
|
relative_step (bool): if True, time-dependent learning rate is computed
|
||||||
|
instead of external learning rate (default: True)
|
||||||
|
warmup_init (bool): time-dependent learning rate computation depends on
|
||||||
|
whether warm-up initialization is being used (default: False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
||||||
|
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
||||||
|
relative_step = lr is None
|
||||||
|
if warmup_init and not relative_step:
|
||||||
|
raise ValueError('warmup_init requires relative_step=True')
|
||||||
|
|
||||||
|
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
||||||
|
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
|
||||||
|
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
|
||||||
|
relative_step=relative_step, warmup_init=warmup_init)
|
||||||
|
super(Adafactor, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_lr(param_group, param_state):
|
||||||
|
if param_group['relative_step']:
|
||||||
|
min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
|
||||||
|
lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
|
||||||
|
param_scale = 1.0
|
||||||
|
if param_group['scale_parameter']:
|
||||||
|
param_scale = max(param_group['eps_scale'], param_state['RMS'])
|
||||||
|
param_group['lr'] = lr_t * param_scale
|
||||||
|
return param_group['lr']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_options(param_group, param_shape):
|
||||||
|
factored = len(param_shape) >= 2
|
||||||
|
use_first_moment = param_group['beta1'] is not None
|
||||||
|
return factored, use_first_moment
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rms(tensor):
|
||||||
|
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||||
|
|
||||||
|
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
|
||||||
|
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
||||||
|
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
||||||
|
return torch.mul(r_factor, c_factor)
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad.data
|
||||||
|
if grad.dtype in {torch.float16, torch.bfloat16}:
|
||||||
|
grad = grad.float()
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('Adafactor does not support sparse gradients.')
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
grad_shape = grad.shape
|
||||||
|
|
||||||
|
factored, use_first_moment = self._get_options(group, grad_shape)
|
||||||
|
# State Initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
|
||||||
|
if use_first_moment:
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(grad)
|
||||||
|
if factored:
|
||||||
|
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
|
||||||
|
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||||
|
else:
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||||
|
|
||||||
|
state['RMS'] = 0
|
||||||
|
else:
|
||||||
|
if use_first_moment:
|
||||||
|
state['exp_avg'] = state['exp_avg'].to(grad)
|
||||||
|
if factored:
|
||||||
|
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
||||||
|
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
||||||
|
else:
|
||||||
|
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
|
||||||
|
|
||||||
|
p_data_fp32 = p.data
|
||||||
|
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
||||||
|
p_data_fp32 = p_data_fp32.float()
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
state['RMS'] = self._rms(p_data_fp32)
|
||||||
|
lr_t = self._get_lr(group, state)
|
||||||
|
|
||||||
|
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
||||||
|
update = grad ** 2 + group['eps']
|
||||||
|
if factored:
|
||||||
|
exp_avg_sq_row = state['exp_avg_sq_row']
|
||||||
|
exp_avg_sq_col = state['exp_avg_sq_col']
|
||||||
|
|
||||||
|
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
|
||||||
|
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
|
||||||
|
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
|
||||||
|
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
||||||
|
|
||||||
|
# Approximation of exponential moving average of square of gradient
|
||||||
|
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||||
|
update.mul_(grad)
|
||||||
|
else:
|
||||||
|
exp_avg_sq = state['exp_avg_sq']
|
||||||
|
|
||||||
|
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
|
||||||
|
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
|
||||||
|
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||||
|
|
||||||
|
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
|
||||||
|
update.mul_(lr_t)
|
||||||
|
|
||||||
|
if use_first_moment:
|
||||||
|
exp_avg = state['exp_avg']
|
||||||
|
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
|
||||||
|
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
|
||||||
|
update = exp_avg
|
||||||
|
|
||||||
|
if group['weight_decay'] != 0:
|
||||||
|
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
|
||||||
|
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
|
||||||
|
|
||||||
|
p_data_fp32.add_(-update)
|
||||||
|
|
||||||
|
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
||||||
|
p.data.copy_(p_data_fp32)
|
||||||
|
|
||||||
|
return loss
|
@ -0,0 +1,156 @@
|
|||||||
|
""" AdaHessian Optimizer
|
||||||
|
|
||||||
|
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
|
||||||
|
Originally licensed MIT, Copyright 2020, David Samuel
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Adahessian(torch.optim.Optimizer):
|
||||||
|
"""
|
||||||
|
Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||||
|
lr (float, optional): learning rate (default: 0.1)
|
||||||
|
betas ((float, float), optional): coefficients used for computing running averages of gradient and the
|
||||||
|
squared hessian trace (default: (0.9, 0.999))
|
||||||
|
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
|
||||||
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
|
||||||
|
hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
|
||||||
|
update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
|
||||||
|
(to save time) (default: 1)
|
||||||
|
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
|
||||||
|
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError(f"Invalid learning rate: {lr}")
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||||
|
if not 0.0 <= hessian_power <= 1.0:
|
||||||
|
raise ValueError(f"Invalid Hessian power value: {hessian_power}")
|
||||||
|
|
||||||
|
self.n_samples = n_samples
|
||||||
|
self.update_each = update_each
|
||||||
|
self.avg_conv_kernel = avg_conv_kernel
|
||||||
|
|
||||||
|
# use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
|
||||||
|
self.seed = 2147483647
|
||||||
|
self.generator = torch.Generator().manual_seed(self.seed)
|
||||||
|
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
|
||||||
|
super(Adahessian, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
for p in self.get_params():
|
||||||
|
p.hess = 0.0
|
||||||
|
self.state[p]["hessian step"] = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_second_order(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
"""
|
||||||
|
Gets all parameters in all param_groups with gradients
|
||||||
|
"""
|
||||||
|
|
||||||
|
return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
|
||||||
|
|
||||||
|
def zero_hessian(self):
|
||||||
|
"""
|
||||||
|
Zeros out the accumalated hessian traces.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for p in self.get_params():
|
||||||
|
if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
|
||||||
|
p.hess.zero_()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def set_hessian(self):
|
||||||
|
"""
|
||||||
|
Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = []
|
||||||
|
for p in filter(lambda p: p.grad is not None, self.get_params()):
|
||||||
|
if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
|
||||||
|
params.append(p)
|
||||||
|
self.state[p]["hessian step"] += 1
|
||||||
|
|
||||||
|
if len(params) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
|
||||||
|
self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
|
||||||
|
|
||||||
|
grads = [p.grad for p in params]
|
||||||
|
|
||||||
|
for i in range(self.n_samples):
|
||||||
|
# Rademacher distribution {-1.0, 1.0}
|
||||||
|
zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
|
||||||
|
h_zs = torch.autograd.grad(
|
||||||
|
grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
|
||||||
|
for h_z, z, p in zip(h_zs, zs, params):
|
||||||
|
p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""
|
||||||
|
Performs a single optimization step.
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
self.zero_hessian()
|
||||||
|
self.set_hessian()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None or p.hess is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.avg_conv_kernel and p.dim() == 4:
|
||||||
|
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
|
||||||
|
|
||||||
|
# Perform correct stepweight decay as in AdamW
|
||||||
|
p.mul_(1 - group['lr'] * group['weight_decay'])
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 1:
|
||||||
|
state['step'] = 0
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(p)
|
||||||
|
# Exponential moving average of Hessian diagonal square values
|
||||||
|
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
|
||||||
|
|
||||||
|
exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
|
||||||
|
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
|
||||||
|
|
||||||
|
bias_correction1 = 1 - beta1 ** state['step']
|
||||||
|
bias_correction2 = 1 - beta2 ** state['step']
|
||||||
|
|
||||||
|
k = group['hessian_power']
|
||||||
|
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
|
||||||
|
|
||||||
|
# make update
|
||||||
|
step_size = group['lr'] / bias_correction1
|
||||||
|
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||||
|
|
||||||
|
return loss
|
Loading…
Reference in new issue