More optimizer cleanup. Change all optimizers to no longer use .data. Improve (b)float16 use with AdaBelief. Add XLA-compatible Lars.

pull/816/head
Ross Wightman 3 years ago
parent 9541f4963b
commit a426511c95
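
The cleanup applied to every optimizer below has the same shape: decorate step() with @torch.no_grad(), run the closure under torch.enable_grad(), and read p.grad / update p in place rather than going through .data. A minimal sketch of that shape, using a made-up SketchOptimizer that is not part of timm:

import torch
from torch.optim.optimizer import Optimizer

class SketchOptimizer(Optimizer):
    """Illustrative only (not in timm): shows the post-cleanup step() shape."""

    @torch.no_grad()                                   # updates run without autograd, so .data is unnecessary
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():                  # closure still needs autograd for its forward/backward
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad                          # was: p.grad.data
                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)   # was: torch.zeros_like(p.data)
                exp_avg = state['exp_avg']
                exp_avg.mul_(0.9).add_(grad, alpha=0.1)
                p.add_(exp_avg, alpha=-group['lr'])    # was: p.data.add_(...)
        return loss

# usage sketch: opt = SketchOptimizer(model.parameters(), defaults=dict(lr=0.1))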

@ -490,6 +490,33 @@ def test_lamb(optimizer):
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-1)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
def test_madgrad(optimizer):
_test_basic_cases(

@ -68,6 +68,7 @@ class AdaBelief(Optimizer):
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
@ -77,14 +78,15 @@ class AdaBelief(Optimizer):
# State initialization
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data)
state['exp_avg_var'] = torch.zeros_like(p)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data)
state['max_exp_avg_var'] = torch.zeros_like(p)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
@ -93,50 +95,47 @@ class AdaBelief(Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
# cast data type
half_precision = False
if p.data.dtype == torch.float16:
half_precision = True
p.data = p.data.float()
p.grad = p.grad.float()
grad = p.grad.data
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError(
'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
p_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_fp32 = p_fp32.float()
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p_fp32)
# Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data)
state['exp_avg_var'] = torch.zeros_like(p_fp32)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data)
state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
# perform weight decay, check if decoupled weight decay
if group['decoupled_decay']:
if not group['fixed_decay']:
p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
else:
p.data.mul_(1.0 - group['weight_decay'])
p_fp32.mul_(1.0 - group['weight_decay'])
else:
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
grad.add_(p_fp32, alpha=group['weight_decay'])
# get current state variable
exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
@ -164,7 +163,7 @@ class AdaBelief(Optimizer):
if not group['rectify']:
# Default update
step_size = group['lr'] / bias_correction1
p.data.addcdiv_(exp_avg, denom, value=-step_size)
p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
# Rectified update, forked from RAdam
buffered = group['buffer'][int(state['step'] % 10)]
@ -192,12 +191,11 @@ class AdaBelief(Optimizer):
if num_sma >= 5:
denom = exp_avg_var.sqrt().add_(group['eps'])
p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
elif step_size > 0:
p.data.add_(exp_avg, alpha=-step_size * group['lr'])
p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
if half_precision:
p.data = p.data.half()
p.grad = p.grad.half()
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_fp32)
return loss
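
The (b)float16 handling added to AdaBelief above boils down to: cast the gradient to fp32, run the update on an fp32 working copy of the parameter (state is allocated to match that copy), then write the result back into the low-precision parameter. A hedged, self-contained sketch of just that pattern; _update_param_fp32 is a made-up helper, not timm code:

import torch

def _update_param_fp32(p, update_fn):
    """Run an optimizer update in fp32 even when p / p.grad are (b)float16. Illustrative only."""
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()                        # compute the update in fp32
    p_fp32 = p.float() if p.dtype in {torch.float16, torch.bfloat16} else p
    update_fn(p_fp32, grad)                        # e.g. the addcdiv_/add_ calls in the diff above
    if p_fp32 is not p:
        p.copy_(p_fp32)                            # write the fp32 result back into the half/bfloat16 param

# usage sketch
w = torch.zeros(4, dtype=torch.bfloat16, requires_grad=True)
w.grad = torch.ones_like(w)
with torch.no_grad():
    _update_param_fp32(w, lambda p32, g: p32.add_(g, alpha=-0.1))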

@ -76,6 +76,7 @@ class Adafactor(torch.optim.Optimizer):
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
@ -83,22 +84,22 @@ class Adafactor(torch.optim.Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError('Adafactor does not support sparse gradients.')
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
factored, use_first_moment = self._get_options(group, grad.shape)
# State Initialization
if len(state) == 0:
state['step'] = 0
@ -107,8 +108,8 @@ class Adafactor(torch.optim.Optimizer):
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(grad)
if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
else:
state['exp_avg_sq'] = torch.zeros_like(grad)
@ -122,12 +123,12 @@ class Adafactor(torch.optim.Optimizer):
else:
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
p_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_fp32 = p_fp32.float()
state['step'] += 1
state['RMS'] = self._rms(p_data_fp32)
state['RMS'] = self._rms(p_fp32)
lr_t = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@ -157,11 +158,10 @@ class Adafactor(torch.optim.Optimizer):
update = exp_avg
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
p_data_fp32.add_(-update)
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
p_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_fp32)
return loss

@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
wd = 1.
expand_size = (-1,) + (1,) * (len(p.shape) - 1)
for view_func in [_channel_view, _layer_view]:
param_view = view_func(p.data)
param_view = view_func(p)
grad_view = view_func(grad)
cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
# FIXME this is a problem for PyTorch XLA
if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
wd = wd_ratio
return perturb, wd
@ -47,17 +48,19 @@ class AdamP(Optimizer):
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
super(AdamP, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
beta1, beta2 = group['betas']
nesterov = group['nesterov']
@ -66,8 +69,8 @@ class AdamP(Optimizer):
# State initialization
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
# Adam
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@ -94,9 +97,9 @@ class AdamP(Optimizer):
# Weight decay
if group['weight_decay'] > 0:
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
# Step
p.data.add_(perturb, alpha=-step_size)
p.add_(perturb, alpha=-step_size)
return loss

@ -55,6 +55,7 @@ class AdamW(Optimizer):
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -64,7 +65,8 @@ class AdamW(Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
@ -75,7 +77,7 @@ class AdamW(Optimizer):
p.data.mul_(1 - group['lr'] * group['weight_decay'])
# Perform optimization step
grad = p.grad.data
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
@ -86,12 +88,12 @@ class AdamW(Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
state['max_exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
@ -115,6 +117,6 @@ class AdamW(Optimizer):
step_size = group['lr'] / bias_correction1
p.data.addcdiv_(exp_avg, denom, value=-step_size)
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss

@ -5,10 +5,14 @@ This optimizer code was adapted from the following (starting with latest)
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason.
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
@ -60,8 +64,7 @@ class Lamb(Optimizer):
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
@ -72,8 +75,7 @@ class Lamb(Optimizer):
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
use_nvlamb (boolean, optional): Apply the adaptive learning rate to parameters with
0.0 weight decay (default: False)
@ -91,25 +93,26 @@ class Lamb(Optimizer):
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
@ -145,15 +148,15 @@ class Lamb(Optimizer):
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.div_(clip_global_grad_norm)
grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@ -166,20 +169,21 @@ class Lamb(Optimizer):
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p.data, alpha=weight_decay)
update.add_(p, alpha=weight_decay)
trust_ratio = one_tensor
if weight_decay != 0 or group['use_nvlamb']:
# Layer adaptation. By default, skip layer adaptation on parameters that are
# excluded from weight decay, unless use_nvlamb == True, then always enabled.
w_norm = p.data.norm(2.0)
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
update.mul_(trust_ratio)
p.data.add_(update, alpha=-group['lr'])
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])
return loss
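
The nested torch.where above (and again in the new Lars further down) is an XLA-safe spelling of a simple conditional: use w_norm / g_norm only when both norms are positive, otherwise 1.0. A standalone restatement of the Lamb variant, with _trust_ratio as a hypothetical helper name:

import torch

def _trust_ratio(w_norm: torch.Tensor, g_norm: torch.Tensor, one: torch.Tensor) -> torch.Tensor:
    # equivalent to: w_norm / g_norm if (w_norm > 0 and g_norm > 0) else 1.0,
    # written with nested torch.where because (per the FIXME above) logical
    # and/or were not reliable under PyTorch XLA at the time
    return torch.where(w_norm > 0, torch.where(g_norm > 0, w_norm / g_norm, one), one)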

@ -0,0 +1,136 @@
""" PyTorch LARS / LARC Optimizer
An implementation of LARS (SGD) + LARC in PyTorch
Based on:
* PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
* NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
Additional cleanup and modifications to properly support PyTorch XLA.
Copyright 2021 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer, required
class Lars(Optimizer):
""" LARS for PyTorch
Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
eps (float): eps for division denominator (default: 1e-8)
larc (bool): enable LARC clipping (default: False)
always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
"""
def __init__(
self,
params,
lr=required,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
trust_coeff=0.001,
eps=1e-8,
larc=False,
always_scale=False,
):
if lr is not required and lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
trust_coeff=trust_coeff,
eps=eps,
larc=larc,
always_scale=always_scale,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
# exclude scaling for params with 0 weight decay
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
trust_coeff = group['trust_coeff']
eps = group['eps']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
# apply LARS scaling, LARC clipping, weight decay
# ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
if weight_decay != 0 or group['always_scale']:
w_norm = p.norm(2.0)
g_norm = grad.norm(2.0)
trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
# FIXME nested where required since logical and/or not working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, trust_ratio, one_tensor),
one_tensor,
)
if group['larc']:
trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
grad.add_(p, alpha=weight_decay)
grad.mul_(trust_ratio)
# apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(grad, alpha=1. - dampening)
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
p.add_(grad, alpha=-group['lr'])
return loss
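
A minimal usage sketch for the new Lars class above; argument names come from the __init__ in this diff, while the import path, values, and tiny model are illustrative:

import torch
import torch.nn as nn
from timm.optim.lars import Lars   # module path assumed from the `from .lars import Lars` import below

model = nn.Linear(10, 10)
# plain LARS + momentum; pass larc=True for LARC-style clipping of the trust ratio
optimizer = Lars(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4, trust_coeff=0.001)

x, y = torch.randn(8, 10), torch.randn(8, 10)
loss = ((model(x) - y) ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()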

@ -27,22 +27,24 @@ class Lookahead(Optimizer):
for group in self._base_optimizer.param_groups:
group.setdefault(name, default)
@torch.no_grad()
def update_slow(self, group):
for fast_p in group["params"]:
if fast_p.grad is None:
continue
param_state = self._base_optimizer.state[fast_p]
if 'lookahead_slow_buff' not in param_state:
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data)
param_state['lookahead_slow_buff'].copy_(fast_p.data)
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
param_state['lookahead_slow_buff'].copy_(fast_p)
slow = param_state['lookahead_slow_buff']
slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
fast_p.data.copy_(slow)
slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
fast_p.copy_(slow)
def sync_lookahead(self):
for group in self._base_optimizer.param_groups:
self.update_slow(group)
@torch.no_grad()
def step(self, closure=None):
loss = self._base_optimizer.step(closure)
for group in self._base_optimizer.param_groups:

@ -82,6 +82,7 @@ class MADGRAD(torch.optim.Optimizer):
def supports_flat_params(self) -> bool:
return True
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
@ -91,13 +92,10 @@ class MADGRAD(torch.optim.Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
# step counter must be stored in state to ensure correct behavior under
# optimizer sharding
if 'k' not in self.state:
self.state['k'] = torch.tensor([0], dtype=torch.long)
k = self.state['k'].item()
step = self.state.setdefault('step', 0) # k
for group in self.param_groups:
eps = group["eps"]
@ -106,19 +104,19 @@ class MADGRAD(torch.optim.Optimizer):
momentum = group["momentum"]
ck = 1 - momentum
lamb = lr * math.pow(k + 1, 0.5)
lamb = lr * math.sqrt(step + 1)
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
state = self.state[p]
if "grad_sum_sq" not in state:
state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
state["s"] = torch.zeros_like(p.data).detach()
state["grad_sum_sq"] = torch.zeros_like(p)
state["s"] = torch.zeros_like(p)
if momentum != 0:
state["x0"] = torch.clone(p.data).detach()
state["x0"] = torch.clone(p).detach()
if momentum != 0.0 and grad.is_sparse:
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
@ -129,11 +127,11 @@ class MADGRAD(torch.optim.Optimizer):
# Apply weight decay
if weight_decay != 0:
if group['decoupled_decay']:
p.data.mul_(1.0 - group['lr'] * weight_decay)
p.mul_(1.0 - group['lr'] * weight_decay)
else:
if grad.is_sparse:
raise RuntimeError("weight_decay option is not compatible with sparse gradients")
grad.add_(p.data, alpha=weight_decay)
grad.add_(p, alpha=weight_decay)
if grad.is_sparse:
grad = grad.coalesce()
@ -161,12 +159,12 @@ class MADGRAD(torch.optim.Optimizer):
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
# Copy updated masked p to dense p using an add operation
p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
p.data.add_(p_masked, alpha=-1)
p.add_(p_masked, alpha=-1)
else:
if momentum == 0:
# Compute x_0 from other known quantities
rms = grad_sum_sq.pow(1 / 3).add_(eps)
x0 = p.data.addcdiv(s, rms, value=1)
x0 = p.addcdiv(s, rms, value=1)
else:
x0 = state["x0"]
@ -175,16 +173,16 @@ class MADGRAD(torch.optim.Optimizer):
rms = grad_sum_sq.pow(1 / 3).add_(eps)
# Update s
s.data.add_(grad, alpha=lamb)
s.add_(grad, alpha=lamb)
# Step
if momentum == 0:
p.data.copy_(x0.addcdiv(s, rms, value=-1))
p.copy_(x0.addcdiv(s, rms, value=-1))
else:
z = x0.addcdiv(s, rms, value=-1)
# p is a moving average of z
p.data.mul_(1 - ck).add_(z, alpha=ck)
p.mul_(1 - ck).add_(z, alpha=ck)
self.state['k'] += 1
self.state['step'] += 1
return loss

@ -1,3 +1,5 @@
import math
import torch
from torch.optim.optimizer import Optimizer
@ -33,6 +35,7 @@ class Nadam(Optimizer):
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
super(Nadam, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -42,21 +45,22 @@ class Nadam(Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state['m_schedule'] = 1.
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
# Warming momentum schedule
m_schedule = state['m_schedule']
@ -66,9 +70,10 @@ class Nadam(Optimizer):
eps = group['eps']
state['step'] += 1
t = state['step']
bias_correction2 = 1 - beta2 ** t
if group['weight_decay'] != 0:
grad = grad.add(p.data, alpha=group['weight_decay'])
grad = grad.add(p, alpha=group['weight_decay'])
momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
@ -79,10 +84,9 @@ class Nadam(Optimizer):
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
denom = exp_avg_sq_prime.sqrt_().add_(eps)
p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
return loss

@ -51,6 +51,7 @@ class NvNovoGrad(Optimizer):
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -60,13 +61,14 @@ class NvNovoGrad(Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
@ -77,7 +79,7 @@ class NvNovoGrad(Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
@ -108,11 +110,11 @@ class NvNovoGrad(Optimizer):
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
grad.add_(p, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(exp_avg, alpha=-group['lr'])
p.add_(exp_avg, alpha=-group['lr'])
return loss

@ -12,6 +12,7 @@ from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
@ -163,6 +164,14 @@ def create_optimizer_v2(
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'larc':
optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
elif opt_lower == 'lars':
optimizer = Lars(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'nlarc':
optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
elif opt_lower == 'nlars':
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'madgradw':
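
The factory additions above map four opt strings onto the single Lars class. A usage sketch via create_optimizer_v2; the positional call mirrors the tests in this diff, while the import path and the lr / weight_decay values are illustrative:

import torch.nn as nn
from timm.optim import create_optimizer_v2   # re-export path assumed

model = nn.Linear(10, 10)
# opt string -> Lars configuration, per the branches above:
#   'lars'  -> Lars(...)                  'larc'  -> Lars(..., larc=True)
#   'nlars' -> Lars(..., nesterov=True)   'nlarc' -> Lars(..., larc=True, nesterov=True)
optimizer = create_optimizer_v2(model.parameters(), 'lars', lr=1.0, weight_decay=1e-5)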

@ -18,31 +18,33 @@ class RAdam(Optimizer):
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
grad = p.grad.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
p_fp32 = p.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
state['exp_avg'] = torch.zeros_like(p_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
@ -73,15 +75,15 @@ class RAdam(Optimizer):
buffered[2] = step_size
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
# more conservative since it's an approximated value
if num_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
p_data_fp32.add_(exp_avg, alpha=-step_size)
p_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)
p.copy_(p_fp32)
return loss

@ -4,7 +4,7 @@ Originally cut & paste from PyTorch RMSProp
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
Modifications Copyright 2020 Ross Wightman
Modifications Copyright 2021 Ross Wightman
"""
import torch
@ -69,6 +69,7 @@ class RMSpropTF(Optimizer):
group.setdefault('momentum', 0)
group.setdefault('centered', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -78,13 +79,14 @@ class RMSpropTF(Optimizer):
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
if grad.is_sparse:
raise RuntimeError('RMSprop does not support sparse gradients')
state = self.state[p]
@ -92,11 +94,11 @@ class RMSpropTF(Optimizer):
# State initialization
if len(state) == 0:
state['step'] = 0
state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero
state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero
if group['momentum'] > 0:
state['momentum_buffer'] = torch.zeros_like(p.data)
state['momentum_buffer'] = torch.zeros_like(p)
if group['centered']:
state['grad_avg'] = torch.zeros_like(p.data)
state['grad_avg'] = torch.zeros_like(p)
square_avg = state['square_avg']
one_minus_alpha = 1. - group['alpha']
@ -105,9 +107,9 @@ class RMSpropTF(Optimizer):
if group['weight_decay'] != 0:
if group['decoupled_decay']:
p.data.mul_(1. - group['lr'] * group['weight_decay'])
p.mul_(1. - group['lr'] * group['weight_decay'])
else:
grad = grad.add(p.data, alpha=group['weight_decay'])
grad = grad.add(p, alpha=group['weight_decay'])
# Tensorflow order of ops for updating squared avg
square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
@ -126,12 +128,12 @@ class RMSpropTF(Optimizer):
# Tensorflow accumulates the LR scaling in the momentum buffer
if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
p.data.add_(-buf)
p.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg)
p.data.add_(buf, alpha=-group['lr'])
p.add_(buf, alpha=-group['lr'])
else:
p.data.addcdiv_(grad, avg, value=-group['lr'])
p.addcdiv_(grad, avg, value=-group['lr'])
return loss

@ -24,10 +24,12 @@ class SGDP(Optimizer):
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
super(SGDP, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
@ -38,12 +40,12 @@ class SGDP(Optimizer):
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
state['momentum'] = torch.zeros_like(p.data)
state['momentum'] = torch.zeros_like(p)
# SGD
buf = state['momentum']
@ -60,9 +62,9 @@ class SGDP(Optimizer):
# Weight decay
if weight_decay != 0:
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
# Step
p.data.add_(d_p, alpha=-group['lr'])
p.add_(d_p, alpha=-group['lr'])
return loss
