Optimizer improvements, additions, cleanup

* Add MADGRAD code
* Fix Lamb (non-fused variant) to work w/ PyTorch XLA
* Tweak optimizer factory args (learning_rate -> lr, optimizer_name -> opt), may break compat; see the usage sketch after the commit metadata below
* Use newer fn signatures for all add_, addcdiv_, addcmul_ calls in optimizers
* Use upcoming PyTorch native Nadam if it's available
* Cleanup lookahead opt
* Add optimizer tests
* Remove novograd.py impl as it was messy, keep nvnovograd
* Make AdamP/SGDP work in channels_last layout
* Add rectified adabelief mode (radabelief)
* Support a few more PyTorch optimizers: adamax, adagrad
pull/813/head
Ross Wightman 3 years ago
parent 3cdaf5ed56
commit ac469b50da
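For reference, a minimal sketch of the renamed factory arguments (assuming the usual package-level export of create_optimizer_v2 from timm.optim; the tiny model is illustrative only):

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(32, 32)
    optimizer = create_optimizer_v2(
        model,                 # may now also be an iterable of params / param-group dicts
        opt='adamw',           # previously optimizer_name=
        lr=1e-3,               # previously learning_rate=
        weight_decay=0.05)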

@ -255,8 +255,8 @@ class TrainBenchmarkRunner(BenchmarkRunner):
self.optimizer = create_optimizer_v2(
self.model,
optimizer_name=kwargs.pop('opt', 'sgd'),
learning_rate=kwargs.pop('lr', 1e-4))
opt=kwargs.pop('opt', 'sgd'),
lr=kwargs.pop('lr', 1e-4))
def _gen_target(self, batch_size):
return torch.empty(

@ -4,7 +4,6 @@ from .adafactor import Adafactor
from .adahessian import Adahessian
from .lookahead import Lookahead
from .nadam import Nadam
from .novograd import NovoGrad
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF

@ -18,7 +18,7 @@ class AdaBelief(Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
weight_decouple (boolean, optional): ( default: True) If set as True, then
decoupled_decay (boolean, optional): ( default: True) If set as True, then
the optimizer uses decoupled weight decay as in AdamW
fixed_decay (boolean, optional): (default: False) This is used when decoupled_decay
is set as True.
@ -39,9 +39,9 @@ class AdaBelief(Optimizer):
- link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
degenerated_to_sgd=True):
def __init__(
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@ -52,21 +52,17 @@ class AdaBelief(Optimizer):
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)])
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
super(AdaBelief, self).__init__(params, defaults)
self.degenerated_to_sgd = degenerated_to_sgd
self.weight_decouple = weight_decouple
self.rectify = rectify
self.fixed_decay = fixed_decay
def __setstate__(self, state):
super(AdaBelief, self).__setstate__(state)
for group in self.param_groups:
@ -133,8 +129,8 @@ class AdaBelief(Optimizer):
state['max_exp_avg_var'] = torch.zeros_like(p.data)
# perform weight decay, check if decoupled weight decay
if self.weight_decouple:
if not self.fixed_decay:
if group['decoupled_decay']:
if not group['fixed_decay']:
p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
else:
p.data.mul_(1.0 - group['weight_decay'])
@ -152,7 +148,7 @@ class AdaBelief(Optimizer):
# Update first and second moment running average
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
grad_residual = grad - exp_avg
exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2)
exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
if amsgrad:
max_exp_avg_var = state['max_exp_avg_var']
@ -165,34 +161,36 @@ class AdaBelief(Optimizer):
denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
# update
if not self.rectify:
if not group['rectify']:
# Default update
step_size = group['lr'] / bias_correction1
p.data.addcdiv_( exp_avg, denom, value=-step_size)
else: # Rectified update, forked from RAdam
p.data.addcdiv_(exp_avg, denom, value=-step_size)
else:
# Rectified update, forked from RAdam
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
num_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
num_sma_max = 2 / (1 - beta2) - 1
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = num_sma
# more conservative since it's an approximated value
if N_sma >= 5:
if num_sma >= 5:
step_size = math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
(1 - beta2_t) *
(num_sma - 4) / (num_sma_max - 4) *
(num_sma - 2) / num_sma *
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
elif group['degenerated_to_sgd']:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size
if N_sma >= 5:
if num_sma >= 5:
denom = exp_avg_var.sqrt().add_(group['eps'])
p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
elif step_size > 0:
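Since rectify, decoupled_decay, fixed_decay and degenerated_to_sgd now live in each param group's defaults instead of on the optimizer instance, per-group overrides become possible; a small sketch (the two-layer split is purely illustrative):

    import torch.nn as nn
    from timm.optim.adabelief import AdaBelief

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    optimizer = AdaBelief(
        [
            {'params': model[0].parameters()},                    # inherits the defaults below
            {'params': model[1].parameters(), 'rectify': False},  # per-group override
        ],
        lr=1e-3, weight_decay=1e-2, decoupled_decay=True, rectify=True)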

@ -34,15 +34,13 @@ class Adafactor(torch.optim.Optimizer):
beta1 (float): coefficient used for computing running averages of gradient (default: None)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
relative_step (bool): if True, time-dependent learning rate is computed
instead of external learning rate (default: True)
warmup_init (bool): time-dependent learning rate computation depends on
whether warm-up initialization is being used (default: False)
"""
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
relative_step = lr is None
relative_step = not lr
if warmup_init and not relative_step:
raise ValueError('warmup_init requires relative_step=True')
@ -138,10 +136,8 @@ class Adafactor(torch.optim.Optimizer):
exp_avg_sq_row = state['exp_avg_sq_row']
exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
@ -149,8 +145,7 @@ class Adafactor(torch.optim.Optimizer):
else:
exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
@ -158,17 +153,15 @@ class Adafactor(torch.optim.Optimizer):
if use_first_moment:
exp_avg = state['exp_avg']
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
update = exp_avg
if group['weight_decay'] != 0:
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
p_data_fp32.add_(-update)
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
return loss
return loss
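The Adafactor hunks above are representative of the signature migration applied throughout this commit: the deprecated (scalar, tensor) overloads of add_, addcmul_ and addcdiv_ are replaced by the tensor-first forms that take the scalar as alpha= / value=. A standalone sketch of the two spellings:

    import torch

    a = torch.zeros(3)
    g = torch.ones(3)
    # deprecated overloads (scalar passed first), as in the removed lines above:
    #   a.add_(0.1, g); a.addcmul_(0.1, g, g); a.addcdiv_(0.1, g, g + 1)
    # current overloads (tensors first, scalar as a keyword):
    a.add_(g, alpha=0.1)
    a.addcmul_(g, g, value=0.1)
    a.addcdiv_(g, g + 1, value=0.1)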

@ -9,48 +9,43 @@ MIT license
"""
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import math
class AdamP(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
super(AdamP, self).__init__(params, defaults)
def _channel_view(self, x):
return x.view(x.size(0), -1)
def _layer_view(self, x):
return x.view(1, -1)
def _channel_view(x) -> torch.Tensor:
return x.reshape(x.size(0), -1)
def _cosine_similarity(self, x, y, eps, view_func):
x = view_func(x)
y = view_func(y)
x_norm = x.norm(dim=1).add_(eps)
y_norm = y.norm(dim=1).add_(eps)
dot = (x * y).sum(dim=1)
def _layer_view(x) -> torch.Tensor:
return x.reshape(1, -1)
return dot.abs() / x_norm / y_norm
def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
wd = 1
expand_size = [-1] + [1] * (len(p.shape) - 1)
for view_func in [self._channel_view, self._layer_view]:
def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
wd = 1.
expand_size = (-1,) + (1,) * (len(p.shape) - 1)
for view_func in [_channel_view, _layer_view]:
param_view = view_func(p.data)
grad_view = view_func(grad)
cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
wd = wd_ratio
return perturb, wd
if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
wd = wd_ratio
return perturb, wd
return perturb, wd
return perturb, wd
class AdamP(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
super(AdamP, self).__init__(params, defaults)
def step(self, closure=None):
loss = None
@ -81,8 +76,8 @@ class AdamP(Optimizer):
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
step_size = group['lr'] / bias_correction1
@ -93,15 +88,15 @@ class AdamP(Optimizer):
perturb = exp_avg / denom
# Projection
wd_ratio = 1
wd_ratio = 1.
if len(p.shape) > 1:
perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
# Weight decay
if group['weight_decay'] > 0:
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
# Step
p.data.add_(-step_size, perturb)
p.data.add_(perturb, alpha=-step_size)
return loss
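With the projection views switched from .view() to .reshape(), AdamP's step no longer assumes contiguous parameters, which is what makes it work for channels_last models; a quick sanity sketch:

    import torch
    import torch.nn as nn
    from timm.optim.adamp import AdamP

    model = nn.Conv2d(3, 8, 3).to(memory_format=torch.channels_last)
    optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2, nesterov=True)

    x = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)
    model(x).mean().backward()
    optimizer.step()  # projection uses reshape(), so the non-contiguous conv weight is fine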

@ -1,5 +1,8 @@
""" AdamW Optimizer
Impl copied from PyTorch master
NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
someday
"""
import math
import torch
@ -100,8 +103,8 @@ class AdamW(Optimizer):
bias_correction2 = 1 - beta2 ** state['step']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
@ -112,6 +115,6 @@ class AdamW(Optimizer):
step_size = group['lr'] / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
p.data.addcdiv_(exp_avg, denom, value=-step_size)
return loss

@ -47,12 +47,13 @@ Original copyrights for above sources are below.
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
from torch.optim import Optimizer
class NvLamb(Optimizer):
class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
@ -82,25 +83,15 @@ class NvLamb(Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0, use_nvlamb=False):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
def __init__(
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
grad_averaging=True, max_grad_norm=1.0, decoupled_decay=False, use_nvlamb=False):
defaults = dict(
lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
decoupled_decay=decoupled_decay, use_nvlamb=use_nvlamb)
super().__init__(params, defaults)
self.set_grad_none = set_grad_none
self.use_nvlamb = use_nvlamb
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(NvLamb, self).zero_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -109,6 +100,7 @@ class NvLamb(Optimizer):
and returns the loss.
"""
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device)
loss = None
if closure is not None:
@ -124,22 +116,18 @@ class NvLamb(Optimizer):
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm_ = torch.sqrt(global_grad_norm)
global_grad_norm = torch.sqrt(global_grad_norm)
max_grad_norm = self.defaults['max_grad_norm']
if global_grad_norm_ > max_grad_norm:
clip_global_grad_norm = global_grad_norm_ / max_grad_norm
else:
clip_global_grad_norm = 1.0
clip_global_grad_norm = torch.where(
global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,
one_tensor)
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
if grad_averaging:
beta3 = 1 - beta1
else:
beta3 = 1.0
beta3 = 1 - beta1 if grad_averaging else 1.0
# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
@ -169,36 +157,35 @@ class NvLamb(Optimizer):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg_, exp_avg_sq_ = state['exp_avg'], state['exp_avg_sq']
decoupled_decay = group['decoupled_decay']
weight_decay = group['weight_decay']
if decoupled_decay and weight_decay != 0:
p.data.mul_(1. - group['lr'] * weight_decay)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
# Decay the first and second moment running average coefficient
# m_t
exp_avg_.mul_(beta1).add_(grad, alpha=beta3)
# v_t
exp_avg_sq_.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# create clones to avoid modifying runner stats
exp_avg = exp_avg_.div(bias_correction1)
exp_avg_sq = exp_avg_sq_.div(bias_correction2)
# || w_t ||
weight_norm = p.data.norm(2.0)
# u_t
exp_avg_sq_sqrt = torch.sqrt(exp_avg_sq)
adam_step = exp_avg.div_(exp_avg_sq_sqrt.add_(group['eps']))
if group['weight_decay'] != 0:
adam_step.add_(p.data, alpha=group['weight_decay'])
# || u_t ||
adam_norm = adam_step.norm(2.0)
if (group['weight_decay'] != 0 or self.use_nvlamb) and adam_norm > 0 and weight_norm > 0:
trust_ratio = weight_norm / adam_norm
trust_ratio = trust_ratio.item()
else:
trust_ratio = 1
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
p.data.add_(adam_step, alpha=-step_size * trust_ratio)
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
if not decoupled_decay and weight_decay != 0:
update.add_(p.data, alpha=weight_decay)
trust_ratio = one_tensor
if weight_decay != 0 or group['use_nvlamb']:
# Layer adaptation. By default, skip layer adaptation on parameters that are
# excluded from weight norm, unless use_nvlamb == True, then always enabled.
w_norm = p.data.norm(2.0)
g_norm = update.norm(2.0)
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
update.mul_(trust_ratio)
p.data.add_(update, alpha=-step_size)
return loss
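The rewritten Lamb keeps the global grad-norm clipping and the trust ratio on-device via torch.where rather than Python branches on tensor values, which is what makes it usable under PyTorch XLA; a minimal usage sketch (decoupled_decay=True corresponds to the experimental 'lambw' factory string):

    import torch
    import torch.nn as nn
    from timm.optim.lamb import Lamb  # class renamed from NvLamb in this commit

    model = nn.Linear(8, 8)
    optimizer = Lamb(model.parameters(), lr=2e-3, weight_decay=0.01)
    # optimizer = Lamb(model.parameters(), lr=2e-3, weight_decay=0.01, decoupled_decay=True)

    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()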

@ -11,82 +11,49 @@ from collections import defaultdict
class Lookahead(Optimizer):
def __init__(self, base_optimizer, alpha=0.5, k=6):
# NOTE super().__init__() not called on purpose
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
self.base_optimizer = base_optimizer
self.param_groups = self.base_optimizer.param_groups
self._base_optimizer = base_optimizer
self.param_groups = base_optimizer.param_groups
self.defaults = base_optimizer.defaults
self.defaults.update(defaults)
self.state = defaultdict(dict)
# manually add our defaults to the param groups
for name, default in defaults.items():
for group in self.param_groups:
for group in self._base_optimizer.param_groups:
group.setdefault(name, default)
def update_slow(self, group):
for fast_p in group["params"]:
if fast_p.grad is None:
continue
param_state = self.state[fast_p]
if 'slow_buffer' not in param_state:
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
param_state['slow_buffer'].copy_(fast_p.data)
slow = param_state['slow_buffer']
slow.add_(group['lookahead_alpha'], fast_p.data - slow)
param_state = self._base_optimizer.state[fast_p]
if 'lookahead_slow_buff' not in param_state:
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data)
param_state['lookahead_slow_buff'].copy_(fast_p.data)
slow = param_state['lookahead_slow_buff']
slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
fast_p.data.copy_(slow)
def sync_lookahead(self):
for group in self.param_groups:
for group in self._base_optimizer.param_groups:
self.update_slow(group)
def step(self, closure=None):
#assert id(self.param_groups) == id(self.base_optimizer.param_groups)
loss = self.base_optimizer.step(closure)
for group in self.param_groups:
loss = self._base_optimizer.step(closure)
for group in self._base_optimizer.param_groups:
group['lookahead_step'] += 1
if group['lookahead_step'] % group['lookahead_k'] == 0:
self.update_slow(group)
return loss
def state_dict(self):
fast_state_dict = self.base_optimizer.state_dict()
slow_state = {
(id(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()
}
fast_state = fast_state_dict['state']
param_groups = fast_state_dict['param_groups']
return {
'state': fast_state,
'slow_state': slow_state,
'param_groups': param_groups,
}
return self._base_optimizer.state_dict()
def load_state_dict(self, state_dict):
fast_state_dict = {
'state': state_dict['state'],
'param_groups': state_dict['param_groups'],
}
self.base_optimizer.load_state_dict(fast_state_dict)
# We want to restore the slow state, but share param_groups reference
# with base_optimizer. This is a bit redundant but least code
slow_state_new = False
if 'slow_state' not in state_dict:
print('Loading state_dict from optimizer without Lookahead applied.')
state_dict['slow_state'] = defaultdict(dict)
slow_state_new = True
slow_state_dict = {
'state': state_dict['slow_state'],
'param_groups': state_dict['param_groups'], # this is pointless but saves code
}
super(Lookahead, self).load_state_dict(slow_state_dict)
self.param_groups = self.base_optimizer.param_groups # make both ref same container
if slow_state_new:
# reapply defaults to catch missing lookahead specific ones
for name, default in self.defaults.items():
for group in self.param_groups:
group.setdefault(name, default)
self._base_optimizer.load_state_dict(state_dict)
self.param_groups = self._base_optimizer.param_groups
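After the cleanup, Lookahead stores its slow buffers inside the wrapped optimizer's per-parameter state and simply forwards state_dict()/load_state_dict(), so checkpoints look the same with or without the wrapper; a sketch:

    import torch
    import torch.nn as nn
    from timm.optim.lookahead import Lookahead

    model = nn.Linear(4, 4)
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base, alpha=0.5, k=6)

    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

    state = optimizer.state_dict()   # just the base optimizer's state, slow buffers included
    optimizer.load_state_dict(state)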

@ -0,0 +1,175 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
import torch.optim
if TYPE_CHECKING:
from torch.optim.optimizer import _params_t
else:
_params_t = Any
class MADGRAD(torch.optim.Optimizer):
"""
MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
Optimization.
.. _MADGRAD: https://arxiv.org/abs/2101.11075
MADGRAD is a general purpose optimizer that can be used in place of SGD or
Adam; it may converge faster and generalize better. Currently GPU-only.
Typically, the same learning rate schedule that is used for SGD or Adam may
be used. The overall learning rate is not comparable to either method and
should be determined by a hyper-parameter sweep.
MADGRAD requires less weight decay than other methods, often as little as
zero. Momentum values used for SGD or Adam's beta1 should work here also.
On sparse problems both weight_decay and momentum should be set to 0.
Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate (default: 1e-2).
momentum (float):
Momentum value in the range [0,1) (default: 0.9).
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
eps (float):
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
"""
def __init__(
self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
):
if momentum < 0 or momentum >= 1:
raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
if lr <= 0:
raise ValueError(f"Learning rate {lr} must be positive")
if weight_decay < 0:
raise ValueError(f"Weight decay {weight_decay} must be non-negative")
if eps < 0:
raise ValueError(f"Eps must be non-negative")
defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
super().__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self) -> bool:
return False
@property
def supports_flat_params(self) -> bool:
return True
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
# step counter must be stored in state to ensure correct behavior under
# optimizer sharding
if 'k' not in self.state:
self.state['k'] = torch.tensor([0], dtype=torch.long)
k = self.state['k'].item()
for group in self.param_groups:
eps = group["eps"]
lr = group["lr"] + eps
decay = group["weight_decay"]
momentum = group["momentum"]
ck = 1 - momentum
lamb = lr * math.pow(k + 1, 0.5)
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
if "grad_sum_sq" not in state:
state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
state["s"] = torch.zeros_like(p.data).detach()
if momentum != 0:
state["x0"] = torch.clone(p.data).detach()
if momentum != 0.0 and grad.is_sparse:
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
grad_sum_sq = state["grad_sum_sq"]
s = state["s"]
# Apply weight decay
if decay != 0:
if grad.is_sparse:
raise RuntimeError("weight_decay option is not compatible with sparse gradients")
grad.add_(p.data, alpha=decay)
if grad.is_sparse:
grad = grad.coalesce()
grad_val = grad._values()
p_masked = p.sparse_mask(grad)
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
s_masked = s.sparse_mask(grad)
# Compute x_0 from other known quantities
rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
# Dense + sparse op
grad_sq = grad * grad
grad_sum_sq.add_(grad_sq, alpha=lamb)
grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
s.add_(grad, alpha=lamb)
s_masked._values().add_(grad_val, alpha=lamb)
# update masked copy of p
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
# Copy updated masked p to dense p using an add operation
p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
p.data.add_(p_masked, alpha=-1)
else:
if momentum == 0:
# Compute x_0 from other known quantities
rms = grad_sum_sq.pow(1 / 3).add_(eps)
x0 = p.data.addcdiv(s, rms, value=1)
else:
x0 = state["x0"]
# Accumulate second moments
grad_sum_sq.addcmul_(grad, grad, value=lamb)
rms = grad_sum_sq.pow(1 / 3).add_(eps)
# Update s
s.data.add_(grad, alpha=lamb)
# Step
if momentum == 0:
p.data.copy_(x0.addcdiv(s, rms, value=-1))
else:
z = x0.addcdiv(s, rms, value=-1)
# p is a moving average of z
p.data.mul_(1 - ck).add_(z, alpha=ck)
self.state['k'] += 1
return loss
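A minimal usage sketch of the newly added MADGRAD (the 'madgrad' factory string maps to this class and passes momentum through):

    import torch
    import torch.nn as nn
    from timm.optim.madgrad import MADGRAD

    model = nn.Linear(4, 2)
    optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()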

@ -1,5 +1,5 @@
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import Optimizer
class Nadam(Optimizer):
@ -27,8 +27,10 @@ class Nadam(Optimizer):
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, schedule_decay=4e-3):
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, schedule_decay=schedule_decay)
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
super(Nadam, self).__init__(params, defaults)
def step(self, closure=None):
@ -53,8 +55,8 @@ class Nadam(Optimizer):
if len(state) == 0:
state['step'] = 0
state['m_schedule'] = 1.
state['exp_avg'] = grad.new().resize_as_(grad).zero_()
state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
# Warming momentum schedule
m_schedule = state['m_schedule']
@ -66,23 +68,21 @@ class Nadam(Optimizer):
t = state['step']
if group['weight_decay'] != 0:
grad = grad.add(group['weight_decay'], p.data)
grad = grad.add(p.data, alpha=group['weight_decay'])
momentum_cache_t = beta1 * \
(1. - 0.5 * (0.96 ** (t * schedule_decay)))
momentum_cache_t_1 = beta1 * \
(1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
m_schedule_new = m_schedule * momentum_cache_t
m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
state['m_schedule'] = m_schedule_new
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1. - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
denom = exp_avg_sq_prime.sqrt_().add_(eps)
p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom)
p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom)
p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
return loss
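The factory now prefers PyTorch's native implementation when present (torch.optim.NAdam, added in 1.10) and falls back to this local Nadam otherwise; the same intent can be sketched as:

    import torch
    import torch.nn as nn
    from timm.optim.nadam import Nadam

    model = nn.Linear(4, 4)
    nadam_cls = getattr(torch.optim, 'NAdam', Nadam)  # native class if available, local impl otherwise
    optimizer = nadam_cls(model.parameters(), lr=2e-3)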

@ -1,77 +0,0 @@
"""NovoGrad Optimizer.
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
- https://arxiv.org/abs/1905.11286
"""
import torch
from torch.optim.optimizer import Optimizer
import math
class NovoGrad(Optimizer):
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(NovoGrad, self).__init__(params, defaults)
self._lr = lr
self._beta1 = betas[0]
self._beta2 = betas[1]
self._eps = eps
self._wd = weight_decay
self._grad_averaging = grad_averaging
self._momentum_initialized = False
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
if not self._momentum_initialized:
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('NovoGrad does not support sparse gradients')
v = torch.norm(grad)**2
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
state['step'] = 0
state['v'] = v
state['m'] = m
state['grad_ema'] = None
self._momentum_initialized = True
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
state['step'] += 1
step, v, m = state['step'], state['v'], state['m']
grad_ema = state['grad_ema']
grad = p.grad.data
g2 = torch.norm(grad)**2
grad_ema = g2 if grad_ema is None else grad_ema * \
self._beta2 + g2 * (1. - self._beta2)
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
if self._grad_averaging:
grad *= (1. - self._beta1)
g2 = torch.norm(grad)**2
v = self._beta2*v + (1. - self._beta2)*g2
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
bias_correction1 = 1 - self._beta1 ** step
bias_correction2 = 1 - self._beta2 ** step
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
state['v'], state['m'] = v, m
state['grad_ema'] = grad_ema
p.data.add_(-step_size, m)
return loss

@ -96,7 +96,7 @@ class NvNovoGrad(Optimizer):
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
@ -108,11 +108,11 @@ class NvNovoGrad(Optimizer):
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(-group['lr'], exp_avg)
p.data.add_(exp_avg, alpha=-group['lr'])
return loss

@ -6,15 +6,16 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import required
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .lamb import NvLamb
from .lamb import Lamb
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .novograd import NovoGrad
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
@ -47,8 +48,8 @@ def optimizer_kwargs(cfg):
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
optimizer_name=cfg.opt,
learning_rate=cfg.lr,
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum)
if getattr(cfg, 'opt_eps', None) is not None:
@ -72,9 +73,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
def create_optimizer_v2(
model: nn.Module,
optimizer_name: str = 'sgd',
learning_rate: Optional[float] = None,
model_or_params,
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
filter_bias_and_bn: bool = True,
@ -87,9 +88,9 @@ def create_optimizer_v2(
* expose the parameters interface and leave it up to caller
Args:
model (nn.Module): model containing parameters to optimize
optimizer_name: name of optimizer to create
learning_rate: initial learning rate
model_or_params (nn.Module or iterable): model containing parameters to optimize, or an iterable of parameters / param group dicts
opt: name of optimizer to create
lr: initial learning rate
weight_decay: weight decay to apply in optimizer
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
@ -98,59 +99,85 @@ def create_optimizer_v2(
Returns:
Optimizer
"""
opt_lower = optimizer_name.lower()
if weight_decay and filter_bias_and_bn:
skip = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
parameters = add_weight_decay(model, weight_decay, skip)
weight_decay = 0.
if isinstance(model_or_params, nn.Module):
# a model was passed in, extract parameters and add weight decays to appropriate layers
if weight_decay and filter_bias_and_bn:
skip = {}
if hasattr(model_or_params, 'no_weight_decay'):
skip = model_or_params.no_weight_decay()
parameters = add_weight_decay(model_or_params, weight_decay, skip)
weight_decay = 0.
else:
parameters = model_or_params.parameters()
else:
parameters = model.parameters()
if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
# iterable of parameters or param groups passed in
parameters = model_or_params
opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
opt_lower = opt.lower()
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None:
opt_args.setdefault('lr', lr)
# basic SGD & related
if opt_lower == 'sgd' or opt_lower == 'nesterov':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'sgdp':
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
# adaptive
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'nadam':
optimizer = Nadam(parameters, **opt_args)
try:
# NOTE PyTorch >= 1.10 should have native NAdam
optimizer = optim.NAdam(parameters, **opt_args)
except AttributeError:
optimizer = Nadam(parameters, **opt_args)
elif opt_lower == 'radam':
optimizer = RAdam(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'sgdp':
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'adamax':
optimizer = optim.Adamax(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'radabelief':
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
elif opt_lower == 'adadelta':
optimizer = optim.Adadelta(parameters, **opt_args)
elif opt_lower == 'adagrad':
opt_args.setdefault('eps', 1e-8)
optimizer = optim.Adagrad(parameters, **opt_args)
elif opt_lower == 'adafactor':
if not learning_rate:
opt_args['lr'] = None
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'lambw':
optimizer = Lamb(parameters, decoupled_decay=True, **opt_args) # FIXME experimental
elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'rmsproptf':
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'novograd':
optimizer = NovoGrad(parameters, **opt_args)
elif opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = NvLamb(parameters, **opt_args)
# second order
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
# NVIDIA fused optimizers, require APEX to be installed
elif opt_lower == 'fusedsgd':
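Taken together, the factory changes mean the first argument can be a module or an iterable of parameters / param-group dicts, and the new strings ('radabelief', 'adamax', 'adagrad', 'madgrad', 'lambw', ...) resolve here; a sketch:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(16, 16)
    optimizer = create_optimizer_v2(
        [{'params': model.weight}, {'params': model.bias, 'weight_decay': 0.}],
        opt='radabelief', lr=1e-3, weight_decay=1e-2)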

@ -4,21 +4,21 @@ Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxi
"""
import math
import torch
from torch.optim.optimizer import Optimizer, required
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
@ -47,105 +47,40 @@ class RAdam(Optimizer):
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
buffered = self.buffer[int(state['step'] % 10)]
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
num_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
num_sma_max = 2 / (1 - beta2) - 1
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = num_sma
# more conservative since it's an approximated value
if N_sma >= 5:
if num_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
(1 - beta2_t) *
(num_sma - 4) / (num_sma_max - 4) *
(num_sma - 2) / num_sma *
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
buffered[2] = step_size
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
# more conservative since it's an approximated value
if N_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
else:
p_data_fp32.add_(-step_size, exp_avg)
p.data.copy_(p_data_fp32)
return loss
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
if num_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
p_data_fp32.add_(-step_size, exp_avg)
p_data_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)

@ -58,8 +58,9 @@ class RMSpropTF(Optimizer):
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
defaults = dict(
lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
super(RMSpropTF, self).__init__(params, defaults)
def __setstate__(self, state):
@ -103,34 +104,34 @@ class RMSpropTF(Optimizer):
state['step'] += 1
if group['weight_decay'] != 0:
if 'decoupled_decay' in group and group['decoupled_decay']:
p.data.add_(-group['weight_decay'], p.data)
if group['decoupled_decay']:
p.data.mul_(1. - group['lr'] * group['weight_decay'])
else:
grad = grad.add(group['weight_decay'], p.data)
grad = grad.add(p.data, alpha=group['weight_decay'])
# Tensorflow order of ops for updating squared avg
square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
# square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
# square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original
if group['centered']:
grad_avg = state['grad_avg']
grad_avg.add_(one_minus_alpha, grad - grad_avg)
# grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt
grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt
# grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original
else:
avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
if group['momentum'] > 0:
buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer
if 'lr_in_momentum' in group and group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
p.data.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg)
p.data.add_(-group['lr'], buf)
p.data.add_(buf, alpha=-group['lr'])
else:
p.data.addcdiv_(-group['lr'], grad, avg)
p.data.addcdiv_(grad, avg, value=-group['lr'])
return loss

@ -9,49 +9,21 @@ MIT license
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required
import math
from .adamp import projection
class SGDP(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
defaults = dict(
lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
super(SGDP, self).__init__(params, defaults)
def _channel_view(self, x):
return x.view(x.size(0), -1)
def _layer_view(self, x):
return x.view(1, -1)
def _cosine_similarity(self, x, y, eps, view_func):
x = view_func(x)
y = view_func(y)
x_norm = x.norm(dim=1).add_(eps)
y_norm = y.norm(dim=1).add_(eps)
dot = (x * y).sum(dim=1)
return dot.abs() / x_norm / y_norm
def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
wd = 1
expand_size = [-1] + [1] * (len(p.shape) - 1)
for view_func in [self._channel_view, self._layer_view]:
cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
wd = wd_ratio
return perturb, wd
return perturb, wd
def step(self, closure=None):
loss = None
if closure is not None:
@ -75,22 +47,22 @@ class SGDP(Optimizer):
# SGD
buf = state['momentum']
buf.mul_(momentum).add_(1 - dampening, grad)
buf.mul_(momentum).add_(grad, alpha=1. - dampening)
if nesterov:
d_p = grad + momentum * buf
else:
d_p = buf
# Projection
wd_ratio = 1
wd_ratio = 1.
if len(p.shape) > 1:
d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
# Weight decay
if weight_decay != 0:
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
# Step
p.data.add_(-group['lr'], d_p)
p.data.add_(d_p, alpha=-group['lr'])
return loss
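SGDP now reuses the module-level projection helper from adamp (and with it the reshape-based, channels_last-safe views); the factory keeps mapping 'sgdp' to SGDP with nesterov momentum, e.g.:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Conv2d(3, 8, 3)
    optimizer = create_optimizer_v2(model, opt='sgdp', lr=0.1, momentum=0.9, weight_decay=1e-4)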
