From 42c1f0cf6c1dd63f46b80d8950579bad67f197c7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Aug 2021 21:05:34 -0700 Subject: [PATCH 1/3] Fix lars tests --- tests/test_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index e1b78482..559fe8d9 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -499,7 +499,7 @@ def test_lars(optimizer): lambda weight, bias: create_optimizer_v2( _build_params_dict(weight, bias, lr=1e-3), optimizer, - lr=1e-1) + lr=1e-3) ) _test_basic_cases( lambda weight, bias: create_optimizer_v2( From c207e02782d5a0b71f59f58d1a4382dfc232d54f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Aug 2021 22:20:35 -0700 Subject: [PATCH 2/3] MOAR optimizer changes. Woo! --- README.md | 9 +++++++++ tests/test_optim.py | 8 ++++---- timm/optim/lamb.py | 21 ++++++++++--------- timm/optim/lars.py | 28 +++++++++++++------------- timm/optim/madgrad.py | 40 +++++++++++++++++-------------------- timm/optim/optim_factory.py | 6 ++++-- 6 files changed, 61 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index f227118e..fda37ca0 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### Aug 18, 2021 +* Optimizer bonanza! + * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)) + * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA) + * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all! + * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself). +* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights. +* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested. + ### July 12, 2021 * Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare) diff --git a/tests/test_optim.py b/tests/test_optim.py index 559fe8d9..c12e33cc 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -463,26 +463,26 @@ def test_adafactor(optimizer): _test_model(optimizer, dict(lr=5e-2)) -@pytest.mark.parametrize('optimizer', ['lamb']) +@pytest.mark.parametrize('optimizer', ['lamb', 'lambc']) def test_lamb(optimizer): _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) ) _test_basic_cases( lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), + _build_params_dict(weight, bias, lr=1e-3), optimizer, lr=1e-3) ) _test_basic_cases( lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), + _build_params_dict_single(weight, bias, lr=1e-3), optimizer, lr=1e-3) ) _test_basic_cases( lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) + _build_params_dict_single(weight, bias, lr=1e-3), optimizer) ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py index a65b4e0f..12c7c49b 100644 --- a/timm/optim/lamb.py +++ b/timm/optim/lamb.py @@ -73,10 +73,9 @@ class Lamb(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) grad_averaging (bool, optional): whether apply (1-beta2) to grad when calculating running averages of gradient. (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) - use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 weight decay parameter (default: False) .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: @@ -87,10 +86,11 @@ class Lamb(Optimizer): def __init__( self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, - weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False): + weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): defaults = dict( lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb) + grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, + trust_clip=trust_clip, always_adapt=always_adapt) super().__init__(params, defaults) @torch.no_grad() @@ -105,7 +105,7 @@ class Lamb(Optimizer): with torch.enable_grad(): loss = closure() - device = self.param_groups[0]["params"][0].device + device = self.param_groups[0]['params'][0].device one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly global_grad_norm = torch.zeros(1, device=device) for group in self.param_groups: @@ -171,9 +171,9 @@ class Lamb(Optimizer): if weight_decay != 0: update.add_(p, alpha=weight_decay) - if weight_decay != 0 or group['use_nvlamb']: - # Layer adaptation. By default, skip layer adaptation on parameters that are - # excluded from weight decay, unless use_nvlamb == True, then always enabled. + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. w_norm = p.norm(2.0) g_norm = update.norm(2.0) # FIXME nested where required since logical and/or not working in PT XLA @@ -182,6 +182,9 @@ class Lamb(Optimizer): torch.where(g_norm > 0, w_norm / g_norm, one_tensor), one_tensor, ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) update.mul_(trust_ratio) p.add_(update, alpha=-group['lr']) diff --git a/timm/optim/lars.py b/timm/optim/lars.py index 7c530e1a..958c2d0e 100644 --- a/timm/optim/lars.py +++ b/timm/optim/lars.py @@ -11,7 +11,7 @@ Additional cleanup and modifications to properly support PyTorch XLA. Copyright 2021 Ross Wightman """ import torch -from torch.optim.optimizer import Optimizer, required +from torch.optim.optimizer import Optimizer class Lars(Optimizer): @@ -21,31 +21,31 @@ class Lars(Optimizer): Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups. - lr (float, optional): learning rate. (default: 1e-3) + lr (float, optional): learning rate (default: 1.0). momentum (float, optional): momentum factor (default: 0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) dampening (float, optional): dampening for momentum (default: 0) nesterov (bool, optional): enables Nesterov momentum (default: False) trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) eps (float): eps for division denominator (default: 1e-8) - larc (bool): enable LARC clipping (default: False) - always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False) + trust_clip (bool): enable LARC trust ratio clipping (default: False) + always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) """ def __init__( self, params, - lr=required, + lr=1.0, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coeff=0.001, eps=1e-8, - larc=False, - always_scale=False, + trust_clip=False, + always_adapt=False, ): - if lr is not required and lr < 0.0: + if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") @@ -62,8 +62,8 @@ class Lars(Optimizer): nesterov=nesterov, trust_coeff=trust_coeff, eps=eps, - larc=larc, - always_scale=always_scale, + trust_clip=trust_clip, + always_adapt=always_adapt, ) super().__init__(params, defaults) @@ -84,7 +84,7 @@ class Lars(Optimizer): with torch.enable_grad(): loss = closure() - device = self.param_groups[0]["params"][0].device + device = self.param_groups[0]['params'][0].device one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly # exclude scaling for params with 0 weight decay @@ -101,9 +101,9 @@ class Lars(Optimizer): continue grad = p.grad - # apply LARS scaling, LARC clipping, weight decay + # apply LARS LR adaptation, LARC clipping, weight decay # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py - if weight_decay != 0 or group['always_scale']: + if weight_decay != 0 or group['always_adapt']: w_norm = p.norm(2.0) g_norm = grad.norm(2.0) trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) @@ -113,7 +113,7 @@ class Lars(Optimizer): torch.where(g_norm > 0, trust_ratio, one_tensor), one_tensor, ) - if group['larc']: + if group['trust_clip']: trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) grad.add(p, alpha=weight_decay) grad.mul_(trust_ratio) diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py index 4d88b753..a76713bf 100644 --- a/timm/optim/madgrad.py +++ b/timm/optim/madgrad.py @@ -87,42 +87,39 @@ class MADGRAD(torch.optim.Optimizer): """Performs a single optimization step. Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() - step = self.state.setdefault('step', 0) # k - for group in self.param_groups: - eps = group["eps"] - lr = group["lr"] + eps - weight_decay = group["weight_decay"] - momentum = group["momentum"] - + eps = group['eps'] + lr = group['lr'] + eps + weight_decay = group['weight_decay'] + momentum = group['momentum'] ck = 1 - momentum - lamb = lr * math.sqrt(step + 1) for p in group["params"]: if p.grad is None: continue grad = p.grad - state = self.state[p] - - if "grad_sum_sq" not in state: - state["grad_sum_sq"] = torch.zeros_like(p) - state["s"] = torch.zeros_like(p) - if momentum != 0: - state["x0"] = torch.clone(p).detach() - if momentum != 0.0 and grad.is_sparse: raise RuntimeError("momentum != 0 is not compatible with sparse gradients") - grad_sum_sq = state["grad_sum_sq"] - s = state["s"] + state = self.state[p] + if len(state) == 0: + state['step'] = 0 + state['grad_sum_sq'] = torch.zeros_like(p) + state['s'] = torch.zeros_like(p) + if momentum != 0: + state['x0'] = torch.clone(p).detach() + + state['step'] += 1 + grad_sum_sq = state['grad_sum_sq'] + s = state['s'] + lamb = lr * math.sqrt(state['step']) # Apply weight decay if weight_decay != 0: @@ -166,7 +163,7 @@ class MADGRAD(torch.optim.Optimizer): rms = grad_sum_sq.pow(1 / 3).add_(eps) x0 = p.addcdiv(s, rms, value=1) else: - x0 = state["x0"] + x0 = state['x0'] # Accumulate second moments grad_sum_sq.addcmul_(grad, grad, value=lamb) @@ -184,5 +181,4 @@ class MADGRAD(torch.optim.Optimizer): # p is a moving average of z p.mul_(1 - ck).add_(z, alpha=ck) - self.state['step'] += 1 return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index bc58014e..e1749156 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -164,12 +164,14 @@ def create_optimizer_v2( optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'lamb': optimizer = Lamb(parameters, **opt_args) + elif opt_lower == 'lambc': + optimizer = Lamb(parameters, trust_clip=True, **opt_args) elif opt_lower == 'larc': - optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args) + optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args) elif opt_lower == 'lars': optimizer = Lars(parameters, momentum=momentum, **opt_args) elif opt_lower == 'nlarc': - optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args) + optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args) elif opt_lower == 'nlars': optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'madgrad': From a16a7538529e8f0e196257a708852ab9ea6ff997 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Aug 2021 22:55:02 -0700 Subject: [PATCH 3/3] Add lamb/lars to optim init imports, remove stray comment --- timm/optim/__init__.py | 8 +++++--- timm/optim/lars.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 2df9fb12..7ee4958e 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -1,7 +1,10 @@ -from .adamp import AdamP -from .adamw import AdamW +from .adabelief import AdaBelief from .adafactor import Adafactor from .adahessian import Adahessian +from .adamp import AdamP +from .adamw import AdamW +from .lamb import Lamb +from .lars import Lars from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam @@ -9,5 +12,4 @@ from .nvnovograd import NvNovoGrad from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .adabelief import AdaBelief from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs diff --git a/timm/optim/lars.py b/timm/optim/lars.py index 958c2d0e..98198e67 100644 --- a/timm/optim/lars.py +++ b/timm/optim/lars.py @@ -87,7 +87,6 @@ class Lars(Optimizer): device = self.param_groups[0]['params'][0].device one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly - # exclude scaling for params with 0 weight decay for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum']