add madgradw optimizer

pull/813/head
Ross Wightman 3 years ago
parent 55fb5eedf6
commit a6af48be64

@@ -490,7 +490,7 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


-@pytest.mark.parametrize('optimizer', ['madgrad'])
+@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)

@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """

     def __init__(
-            self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+            self,
+            params: _params_t,
+            lr: float = 1e-2,
+            momentum: float = 0.9,
+            weight_decay: float = 0,
+            eps: float = 1e-6,
+            decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -64,7 +70,8 @@ class MADGRAD(torch.optim.Optimizer):
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")

-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)

     @property
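Note: with the new argument in place, the optimizer can also be constructed directly with decoupled decay enabled. A minimal sketch, assuming MADGRAD is exported from timm.optim; the module and hyperparameter values are placeholders, not taken from this commit:

    import torch
    from timm.optim import MADGRAD

    model = torch.nn.Linear(10, 2)  # placeholder module
    # decoupled_decay=True shrinks the weights directly (AdamW-style)
    # instead of adding an L2 term to the gradient
    optimizer = MADGRAD(
        model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4, decoupled_decay=True)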
@@ -95,7 +102,7 @@ class MADGRAD(torch.optim.Optimizer):
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]

             ck = 1 - momentum
@@ -120,11 +127,13 @@ class MADGRAD(torch.optim.Optimizer):
                 s = state["s"]

                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)

                 if grad.is_sparse:
                     grad = grad.coalesce()
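The branch above is where the madgradw variant differs: with decoupled_decay the decay multiplicatively shrinks the weights (AdamW-style), while the default path folds it into the gradient as L2 regularization, where it is then subject to MADGRAD's adaptive scaling. A minimal sketch of the two styles, using a plain SGD-like step purely for illustration (not the MADGRAD update itself):

    import torch

    def toy_step(p, grad, lr=0.1, weight_decay=1e-2, decoupled=False):
        # illustrative only: shows where the decay term enters the update
        if decoupled:
            # decoupled: shrink weights directly, independent of any
            # later rescaling of the gradient (madgradw behaviour)
            p = p * (1.0 - lr * weight_decay)
        else:
            # coupled (L2): decay becomes part of the gradient and is
            # affected by whatever scaling the optimizer applies to it
            grad = grad + weight_decay * p
        return p - lr * grad

    w = torch.ones(3)
    g = torch.full((3,), 0.5)
    toy_step(w, g, decoupled=True)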

@@ -165,6 +165,8 @@ def create_optimizer_v2(
         optimizer = Lamb(parameters, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'madgradw':
+        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
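With the factory hook added, the variant can be selected by name, mirroring the test usage above. A sketch, assuming create_optimizer_v2 is importable from timm.optim and using placeholder parameters:

    import torch
    from timm.optim import create_optimizer_v2

    weight = torch.nn.Parameter(torch.randn(4, 4))
    bias = torch.nn.Parameter(torch.zeros(4))

    # 'madgrad' keeps the default L2-style decay; 'madgradw' routes through
    # the same class with decoupled_decay=True
    opt = create_optimizer_v2([weight, bias], 'madgradw', lr=1e-3, weight_decay=1e-4)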
