From a6af48be64620f2be49e2aa4b04c1a6cbe4a4198 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 17 Aug 2021 22:19:27 -0700
Subject: [PATCH] add madgradw optimizer

---
 tests/test_optim.py         |  2 +-
 timm/optim/madgrad.py       | 25 +++++++++++++++++--------
 timm/optim/optim_factory.py |  2 ++
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/tests/test_optim.py b/tests/test_optim.py
index 5a0a677c..eacc8e29 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -490,7 +490,7 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
-@pytest.mark.parametrize('optimizer', ['madgrad'])
+@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py
index f9ab24e3..7f8d73e8 100644
--- a/timm/optim/madgrad.py
+++ b/timm/optim/madgrad.py
@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """
 
     def __init__(
-        self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+            self,
+            params: _params_t,
+            lr: float = 1e-2,
+            momentum: float = 0.9,
+            weight_decay: float = 0,
+            eps: float = 1e-6,
+            decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -64,7 +70,8 @@ class MADGRAD(torch.optim.Optimizer):
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")
 
-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)
 
     @property
@@ -95,7 +102,7 @@ class MADGRAD(torch.optim.Optimizer):
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]
 
             ck = 1 - momentum
@@ -120,11 +127,13 @@ class MADGRAD(torch.optim.Optimizer):
                 s = state["s"]
 
                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index c4a3c101..2157f73d 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -165,6 +165,8 @@ def create_optimizer_v2(
         optimizer = Lamb(parameters, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'madgradw':
+        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
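
Usage note (not part of the patch): a minimal sketch of how the new 'madgradw' selector and the underlying decoupled_decay flag could be exercised, assuming create_optimizer_v2 accepts a module or parameter list plus keyword hyperparameters as in the test change above; the toy model and hyperparameter values are illustrative assumptions, not taken from the patch.

    import torch.nn as nn
    from timm.optim import create_optimizer_v2
    from timm.optim.madgrad import MADGRAD

    model = nn.Linear(10, 2)  # toy model, purely illustrative

    # Via the factory: 'madgradw' maps to MADGRAD(..., decoupled_decay=True) per the patch above.
    opt = create_optimizer_v2(model, 'madgradw', lr=1e-3, weight_decay=1e-2, momentum=0.9)

    # Equivalent direct construction. With decoupled_decay=True the weights are scaled by
    # (1 - lr * weight_decay) each step (AdamW-style) instead of adding weight_decay * p
    # to the gradient.
    opt = MADGRAD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-2, decoupled_decay=True)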