MOAR optimizer changes. Woo!

pull/816/head
Ross Wightman 3 years ago
parent 42c1f0cf6c
commit c207e02782

@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
### July 12, 2021
* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)

@ -463,26 +463,26 @@ def test_adafactor(optimizer):
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['lamb'])
@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
def test_lamb(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)

@ -73,10 +73,9 @@ class Lamb(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
@ -87,10 +86,11 @@ class Lamb(Optimizer):
def __init__(
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
defaults = dict(
lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
trust_clip=trust_clip, always_adapt=always_adapt)
super().__init__(params, defaults)
@torch.no_grad()
@ -105,7 +105,7 @@ class Lamb(Optimizer):
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
@ -171,9 +171,9 @@ class Lamb(Optimizer):
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group['use_nvlamb']:
# Layer adaptation. By default, skip layer adaptation on parameters that are
# excluded from weight decay, unless use_nvlamb == True, then always enabled.
if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are
# excluded from weight decay, unless always_adapt == True, then always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not working in PT XLA
@ -182,6 +182,9 @@ class Lamb(Optimizer):
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])

@ -11,7 +11,7 @@ Additional cleanup and modifications to properly support PyTorch XLA.
Copyright 2021 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer, required
from torch.optim.optimizer import Optimizer
class Lars(Optimizer):
@ -21,31 +21,31 @@ class Lars(Optimizer):
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
lr (float, optional): learning rate (default: 1.0).
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
eps (float): eps for division denominator (default: 1e-8)
larc (bool): enable LARC clipping (default: False)
always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
trust_clip (bool): enable LARC trust ratio clipping (default: False)
always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
"""
def __init__(
self,
params,
lr=required,
lr=1.0,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
trust_coeff=0.001,
eps=1e-8,
larc=False,
always_scale=False,
trust_clip=False,
always_adapt=False,
):
if lr is not required and lr < 0.0:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
@ -62,8 +62,8 @@ class Lars(Optimizer):
nesterov=nesterov,
trust_coeff=trust_coeff,
eps=eps,
larc=larc,
always_scale=always_scale,
trust_clip=trust_clip,
always_adapt=always_adapt,
)
super().__init__(params, defaults)
@ -84,7 +84,7 @@ class Lars(Optimizer):
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
# exclude scaling for params with 0 weight decay
@ -101,9 +101,9 @@ class Lars(Optimizer):
continue
grad = p.grad
# apply LARS scaling, LARC clipping, weight decay
# apply LARS LR adaptation, LARC clipping, weight decay
# ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
if weight_decay != 0 or group['always_scale']:
if weight_decay != 0 or group['always_adapt']:
w_norm = p.norm(2.0)
g_norm = grad.norm(2.0)
trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
@ -113,7 +113,7 @@ class Lars(Optimizer):
torch.where(g_norm > 0, trust_ratio, one_tensor),
one_tensor,
)
if group['larc']:
if group['trust_clip']:
trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
grad.add(p, alpha=weight_decay)
grad.mul_(trust_ratio)

@ -87,42 +87,39 @@ class MADGRAD(torch.optim.Optimizer):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
step = self.state.setdefault('step', 0) # k
for group in self.param_groups:
eps = group["eps"]
lr = group["lr"] + eps
weight_decay = group["weight_decay"]
momentum = group["momentum"]
eps = group['eps']
lr = group['lr'] + eps
weight_decay = group['weight_decay']
momentum = group['momentum']
ck = 1 - momentum
lamb = lr * math.sqrt(step + 1)
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "grad_sum_sq" not in state:
state["grad_sum_sq"] = torch.zeros_like(p)
state["s"] = torch.zeros_like(p)
if momentum != 0:
state["x0"] = torch.clone(p).detach()
if momentum != 0.0 and grad.is_sparse:
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
grad_sum_sq = state["grad_sum_sq"]
s = state["s"]
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['grad_sum_sq'] = torch.zeros_like(p)
state['s'] = torch.zeros_like(p)
if momentum != 0:
state['x0'] = torch.clone(p).detach()
state['step'] += 1
grad_sum_sq = state['grad_sum_sq']
s = state['s']
lamb = lr * math.sqrt(state['step'])
# Apply weight decay
if weight_decay != 0:
@ -166,7 +163,7 @@ class MADGRAD(torch.optim.Optimizer):
rms = grad_sum_sq.pow(1 / 3).add_(eps)
x0 = p.addcdiv(s, rms, value=1)
else:
x0 = state["x0"]
x0 = state['x0']
# Accumulate second moments
grad_sum_sq.addcmul_(grad, grad, value=lamb)
@ -184,5 +181,4 @@ class MADGRAD(torch.optim.Optimizer):
# p is a moving average of z
p.mul_(1 - ck).add_(z, alpha=ck)
self.state['step'] += 1
return loss

@ -164,12 +164,14 @@ def create_optimizer_v2(
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'lambc':
optimizer = Lamb(parameters, trust_clip=True, **opt_args)
elif opt_lower == 'larc':
optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
elif opt_lower == 'lars':
optimizer = Lars(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'nlarc':
optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
elif opt_lower == 'nlars':
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'madgrad':

Loading…
Cancel
Save