@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """
 
     def __init__(
-        self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+        self,
+        params: _params_t,
+        lr: float = 1e-2,
+        momentum: float = 0.9,
+        weight_decay: float = 0,
+        eps: float = 1e-6,
+        decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -64,7 +70,8 @@ class MADGRAD(torch.optim.Optimizer):
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")
 
-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)
 
     @property
@@ -95,7 +102,7 @@ class MADGRAD(torch.optim.Optimizer):
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]
 
             ck = 1 - momentum
@@ -120,11 +127,13 @@ class MADGRAD(torch.optim.Optimizer):
                 s = state["s"]
 
                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
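
For readers skimming the patch: the new `decoupled_decay` flag selects between two ways of applying weight decay. The default path folds the decay term into the gradient, so it also passes through MADGRAD's adaptive dual-averaging update like any other gradient component; the decoupled path shrinks the parameters directly by `lr * weight_decay` before the gradient step, in the style of AdamW. A minimal sketch of the two paths, with toy tensors and hyperparameter values chosen purely for illustration (they are assumptions, not part of the patch):

    import torch

    # Hypothetical toy setup to contrast the two decay branches above.
    p = torch.nn.Parameter(torch.ones(3))
    p.grad = torch.full((3,), 0.5)
    lr, weight_decay = 1e-2, 1e-4  # illustrative values only

    # Coupled (decoupled_decay=False): decay enters through the gradient,
    # so downstream it is rescaled by the adaptive denominator.
    grad = p.grad.clone()
    grad.add_(p.data, alpha=weight_decay)

    # Decoupled (decoupled_decay=True): the parameter shrinks toward zero
    # directly, scaled only by the learning rate, before the gradient step.
    p.data.mul_(1.0 - lr * weight_decay)

With the patch applied, opting in looks like `MADGRAD(model.parameters(), lr=1e-2, weight_decay=1e-4, decoupled_decay=True)`. Note that the sparse-gradient guard only fires on the coupled path, since the decoupled multiply never touches the gradient.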