MOAR optimizer changes. Woo!

pull/816/head
Ross Wightman 3 years ago
parent 42c1f0cf6c
commit c207e02782

README.md
@@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 ## What's New
+### Aug 18, 2021
+* Optimizer bonanza!
+  * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
+  * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
+  * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
+  * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
+* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit different and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
+* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE variant, so more of a curiosity for those interested.
 ### July 12, 2021
 * Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
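For orientation, here is a hedged usage sketch of the optimizer additions listed above, exercised through timm's factory (the same entry point the tests and `optim_factory` hunks below touch). The model and hyper-parameters are arbitrary, and the exact `create_optimizer_v2` signature may vary a little between versions.

```python
# Minimal sketch, assuming timm exports create_optimizer_v2 as used in the tests below.
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# LAMB, and 'lambc' which enables LAMBC trust ratio clipping
opt_lamb = create_optimizer_v2(model.parameters(), 'lamb', lr=1e-3, weight_decay=0.01)
opt_lambc = create_optimizer_v2(model.parameters(), 'lambc', lr=1e-3, weight_decay=0.01)

# LARS family: 'lars', 'larc' (trust ratio clipping), 'nlars'/'nlarc' add Nesterov momentum
opt_larc = create_optimizer_v2(model.parameters(), 'larc', lr=1.0, momentum=0.9, weight_decay=1e-5)

# MADGRAD as added in this change set
opt_madgrad = create_optimizer_v2(model.parameters(), 'madgrad', lr=1e-2, weight_decay=1e-5)
```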

tests/test_optim.py
@@ -463,26 +463,26 @@ def test_adafactor(optimizer):
     _test_model(optimizer, dict(lr=5e-2))


-@pytest.mark.parametrize('optimizer', ['lamb'])
+@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
 def test_lamb(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
+            _build_params_dict(weight, bias, lr=1e-3),
             optimizer,
             lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
+            _build_params_dict_single(weight, bias, lr=1e-3),
             optimizer,
             lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
+            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)

timm/optim/lamb.py
@@ -73,10 +73,9 @@ class Lamb(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         grad_averaging (bool, optional): whether apply (1-beta2) to grad when
             calculating running averages of gradient. (default: True)
-        set_grad_none (bool, optional): whether set grad to None when zero_grad()
-            method is called. (default: True)
         max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
-        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)

     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
@@ -87,10 +86,11 @@ class Lamb(Optimizer):

     def __init__(
             self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
+            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
         defaults = dict(
             lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
+            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip, always_adapt=always_adapt)
         super().__init__(params, defaults)

     @torch.no_grad()
@@ -105,7 +105,7 @@ class Lamb(Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        device = self.param_groups[0]["params"][0].device
+        device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
         global_grad_norm = torch.zeros(1, device=device)
         for group in self.param_groups:
@@ -171,9 +171,9 @@ class Lamb(Optimizer):
                 if weight_decay != 0:
                     update.add_(p, alpha=weight_decay)

-                if weight_decay != 0 or group['use_nvlamb']:
-                    # Layer adaptation. By default, skip layer adaptation on parameters that are
-                    # excluded from weight decay, unless use_nvlamb == True, then always enabled.
+                if weight_decay != 0 or group['always_adapt']:
+                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
+                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
                     # FIXME nested where required since logical and/or not working in PT XLA
@@ -182,6 +182,9 @@ class Lamb(Optimizer):
                         torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                         one_tensor,
                     )
+                    if group['trust_clip']:
+                        # LAMBC trust clipping, upper bound fixed at one
+                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                     update.mul_(trust_ratio)

                 p.add_(update, alpha=-group['lr'])
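To make the renamed LAMB options concrete, here is a minimal standalone sketch of the layer-wise trust ratio from the hunk above, including the optional LAMBC clamp controlled by `trust_clip`; the function name and tensors are illustrative only, not part of the optimizer's API.

```python
import torch


def lamb_trust_ratio(param: torch.Tensor, update: torch.Tensor, trust_clip: bool = False) -> torch.Tensor:
    # Layer-wise LR adaptation factor as used by LAMB: ||w|| / ||update||,
    # falling back to 1.0 when either norm is zero (nested where, XLA-friendly).
    one = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = update.norm(2.0)
    ratio = torch.where(w_norm > 0, torch.where(g_norm > 0, w_norm / g_norm, one), one)
    if trust_clip:
        # LAMBC: upper-bound the trust ratio at 1 so no layer gets scaled up
        ratio = torch.minimum(ratio, one)
    return ratio


# e.g. the parameter step then becomes: p -= lr * lamb_trust_ratio(p, update, trust_clip=True) * update
```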

timm/optim/lars.py
@@ -11,7 +11,7 @@ Additional cleanup and modifications to properly support PyTorch XLA.

 Copyright 2021 Ross Wightman
 """
 import torch
-from torch.optim.optimizer import Optimizer, required
+from torch.optim.optimizer import Optimizer


 class Lars(Optimizer):
@@ -21,31 +21,31 @@ class Lars(Optimizer):

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. (default: 1e-3)
+        lr (float, optional): learning rate (default: 1.0).
         momentum (float, optional): momentum factor (default: 0)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
         trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
         eps (float): eps for division denominator (default: 1e-8)
-        larc (bool): enable LARC clipping (default: False)
-        always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
+        trust_clip (bool): enable LARC trust ratio clipping (default: False)
+        always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
     """

     def __init__(
             self,
             params,
-            lr=required,
+            lr=1.0,
             momentum=0,
             dampening=0,
             weight_decay=0,
             nesterov=False,
             trust_coeff=0.001,
             eps=1e-8,
-            larc=False,
-            always_scale=False,
+            trust_clip=False,
+            always_adapt=False,
     ):
-        if lr is not required and lr < 0.0:
+        if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr}")
         if momentum < 0.0:
             raise ValueError(f"Invalid momentum value: {momentum}")
@@ -62,8 +62,8 @@ class Lars(Optimizer):
             nesterov=nesterov,
             trust_coeff=trust_coeff,
             eps=eps,
-            larc=larc,
-            always_scale=always_scale,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
         )
         super().__init__(params, defaults)
@@ -84,7 +84,7 @@ class Lars(Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        device = self.param_groups[0]["params"][0].device
+        device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

         # exclude scaling for params with 0 weight decay
@@ -101,9 +101,9 @@ class Lars(Optimizer):
                     continue
                 grad = p.grad

-                # apply LARS scaling, LARC clipping, weight decay
+                # apply LARS LR adaptation, LARC clipping, weight decay
                 # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
-                if weight_decay != 0 or group['always_scale']:
+                if weight_decay != 0 or group['always_adapt']:
                     w_norm = p.norm(2.0)
                     g_norm = grad.norm(2.0)
                     trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
@@ -113,7 +113,7 @@ class Lars(Optimizer):
                         torch.where(g_norm > 0, trust_ratio, one_tensor),
                         one_tensor,
                     )
-                    if group['larc']:
+                    if group['trust_clip']:
                         trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
                     grad.add_(p, alpha=weight_decay)
                     grad.mul_(trust_ratio)
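And the LARS counterpart: a hedged sketch of the adaptation with the renamed `trust_clip` (LARC-style) clamp from the hunks above. The constants mirror the diff (`trust_coeff=0.001`, `eps=1e-8`); the function itself is illustrative, not the optimizer's internals.

```python
import torch


def lars_trust_ratio(
        param: torch.Tensor,
        grad: torch.Tensor,
        lr: float,
        weight_decay: float = 0.0,
        trust_coeff: float = 0.001,
        eps: float = 1e-8,
        trust_clip: bool = False,
) -> torch.Tensor:
    # LARS layer-wise LR adaptation: trust_coeff * ||w|| / (||g|| + wd * ||w|| + eps)
    one = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = grad.norm(2.0)
    ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
    # degenerate layers (zero weight or grad norm) fall back to 1.0, as in the diff
    ratio = torch.where(w_norm > 0, torch.where(g_norm > 0, ratio, one), one)
    if trust_clip:
        # LARC clipping: cap the effective LR (ratio * lr) at the group LR
        ratio = torch.minimum(ratio / lr, one)
    return ratio
```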

timm/optim/madgrad.py
@@ -87,42 +87,39 @@ class MADGRAD(torch.optim.Optimizer):
         """Performs a single optimization step.

         Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure (callable, optional): A closure that reevaluates the model and returns the loss.
         """
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

-        step = self.state.setdefault('step', 0)  # k

         for group in self.param_groups:
-            eps = group["eps"]
-            lr = group["lr"] + eps
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
+            eps = group['eps']
+            lr = group['lr'] + eps
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
             ck = 1 - momentum
-            lamb = lr * math.sqrt(step + 1)

             for p in group["params"]:
                 if p.grad is None:
                     continue
                 grad = p.grad
-                state = self.state[p]
-                if "grad_sum_sq" not in state:
-                    state["grad_sum_sq"] = torch.zeros_like(p)
-                    state["s"] = torch.zeros_like(p)
-                    if momentum != 0:
-                        state["x0"] = torch.clone(p).detach()
                 if momentum != 0.0 and grad.is_sparse:
                     raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

-                grad_sum_sq = state["grad_sum_sq"]
-                s = state["s"]
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['grad_sum_sq'] = torch.zeros_like(p)
+                    state['s'] = torch.zeros_like(p)
+                    if momentum != 0:
+                        state['x0'] = torch.clone(p).detach()
+                state['step'] += 1
+                grad_sum_sq = state['grad_sum_sq']
+                s = state['s']
+                lamb = lr * math.sqrt(state['step'])

                 # Apply weight decay
                 if weight_decay != 0:
@@ -166,7 +163,7 @@ class MADGRAD(torch.optim.Optimizer):
                     rms = grad_sum_sq.pow(1 / 3).add_(eps)
                     x0 = p.addcdiv(s, rms, value=1)
                 else:
-                    x0 = state["x0"]
+                    x0 = state['x0']

                 # Accumulate second moments
                 grad_sum_sq.addcmul_(grad, grad, value=lamb)
@@ -184,5 +181,4 @@ class MADGRAD(torch.optim.Optimizer):
                     # p is a moving average of z
                     p.mul_(1 - ck).add_(z, alpha=ck)

-        self.state['step'] += 1
         return loss
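The step-counter change above moves the counter from a single `self.state['step']` entry into each parameter's own state, initialized lazily alongside `grad_sum_sq` and `s`, with `lamb = lr * sqrt(step)` now derived per parameter inside the loop (part of the step-handling cleanup called out in the notes above). A rough sketch of the pattern in a toy optimizer, not MADGRAD itself:

```python
import torch


class TinySGD(torch.optim.Optimizer):
    # Toy optimizer illustrating the lazily-initialized per-parameter step counter pattern.

    def __init__(self, params, lr=1e-2):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]      # per-parameter state, checkpointed with the optimizer
                if len(state) == 0:
                    state['step'] = 0      # initialized on first use, like grad_sum_sq / s above
                state['step'] += 1
                # a step-dependent coefficient, loosely analogous to MADGRAD's lamb = lr * sqrt(step)
                scale = group['lr'] / (state['step'] ** 0.5)
                p.add_(p.grad, alpha=-scale)
        return loss
```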

timm/optim/optim_factory.py
@@ -164,12 +164,14 @@ def create_optimizer_v2(
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'lamb':
         optimizer = Lamb(parameters, **opt_args)
+    elif opt_lower == 'lambc':
+        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
     elif opt_lower == 'larc':
-        optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
+        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
     elif opt_lower == 'lars':
         optimizer = Lars(parameters, momentum=momentum, **opt_args)
     elif opt_lower == 'nlarc':
-        optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
+        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
     elif opt_lower == 'nlars':
         optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'madgrad':
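For anyone constructing these optimizers directly rather than through the factory, the new strings map to constructor flags roughly as follows; a sketch only, assuming `Lamb` and `Lars` are importable from `timm.optim`, with arbitrary hyper-parameters.

```python
import torch.nn as nn
from timm.optim import Lamb, Lars  # assumed exports, mirroring the factory above

model = nn.Linear(16, 4)

# 'lamb'  -> Lamb(...); 'lambc' -> Lamb(..., trust_clip=True)
opt_lambc = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01, trust_clip=True)

# 'lars' -> Lars(...); 'larc' -> Lars(..., trust_clip=True);
# 'nlars' / 'nlarc' additionally pass nesterov=True
opt_nlarc = Lars(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-5, nesterov=True, trust_clip=True)
```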
