Merge branch 'master' into bits_and_tpu

pull/1239/head
Ross Wightman 3 years ago
commit f4fb068b11

@ -490,6 +490,33 @@ def test_lamb(optimizer):
_test_model(optimizer, dict(lr=1e-3)) _test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
    """Smoke-test the LARS optimizer family via the timm optimizer factory.

    Parametrized over the factory names for LARS, LARC, and their Nesterov
    variants.  Each case builds the optimizer through ``create_optimizer_v2``
    and runs the shared harnesses (``_test_basic_cases``, ``_test_rosenbrock``,
    ``_test_model``) defined elsewhere in this test module.
    """
    # Plain param list, optimizer-level lr only.
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    # Param-group dict with a per-group lr overriding the optimizer lr.
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-1)
    )
    # Single param-group dict form, group lr matching optimizer lr.
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-3)
    )
    # Single param-group dict relying solely on the group's lr (no optimizer lr).
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
    )
    # Convergence check on the Rosenbrock function.
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    # End-to-end: a tiny model must train without error.
    _test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw']) @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
def test_madgrad(optimizer): def test_madgrad(optimizer):
_test_basic_cases( _test_basic_cases(

@ -68,6 +68,7 @@ class AdaBelief(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('amsgrad', False) group.setdefault('amsgrad', False)
@torch.no_grad()
def reset(self): def reset(self):
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
@ -77,14 +78,15 @@ class AdaBelief(Optimizer):
# State initialization # State initialization
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data) state['exp_avg_var'] = torch.zeros_like(p)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data) state['max_exp_avg_var'] = torch.zeros_like(p)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
@ -93,50 +95,47 @@ class AdaBelief(Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad
# cast data type if grad.dtype in {torch.float16, torch.bfloat16}:
half_precision = False grad = grad.float()
if p.data.dtype == torch.float16:
half_precision = True
p.data = p.data.float()
p.grad = p.grad.float()
grad = p.grad.data
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError(
'AdaBelief does not support sparse gradients, please consider SparseAdam instead') 'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p] p_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_fp32 = p_fp32.float()
amsgrad = group['amsgrad']
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p_fp32)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data) state['exp_avg_var'] = torch.zeros_like(p_fp32)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data) state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
# perform weight decay, check if decoupled weight decay # perform weight decay, check if decoupled weight decay
if group['decoupled_decay']: if group['decoupled_decay']:
if not group['fixed_decay']: if not group['fixed_decay']:
p.data.mul_(1.0 - group['lr'] * group['weight_decay']) p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
else: else:
p.data.mul_(1.0 - group['weight_decay']) p_fp32.mul_(1.0 - group['weight_decay'])
else: else:
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay']) grad.add_(p_fp32, alpha=group['weight_decay'])
# get current state variable # get current state variable
exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
@ -164,7 +163,7 @@ class AdaBelief(Optimizer):
if not group['rectify']: if not group['rectify']:
# Default update # Default update
step_size = group['lr'] / bias_correction1 step_size = group['lr'] / bias_correction1
p.data.addcdiv_(exp_avg, denom, value=-step_size) p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else: else:
# Rectified update, forked from RAdam # Rectified update, forked from RAdam
buffered = group['buffer'][int(state['step'] % 10)] buffered = group['buffer'][int(state['step'] % 10)]
@ -192,12 +191,11 @@ class AdaBelief(Optimizer):
if num_sma >= 5: if num_sma >= 5:
denom = exp_avg_var.sqrt().add_(group['eps']) denom = exp_avg_var.sqrt().add_(group['eps'])
p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
elif step_size > 0: elif step_size > 0:
p.data.add_(exp_avg, alpha=-step_size * group['lr']) p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
if half_precision: if p.dtype in {torch.float16, torch.bfloat16}:
p.data = p.data.half() p.copy_(p_fp32)
p.grad = p.grad.half()
return loss return loss

@ -76,6 +76,7 @@ class Adafactor(torch.optim.Optimizer):
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor) return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
@ -83,22 +84,22 @@ class Adafactor(torch.optim.Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}: if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float() grad = grad.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('Adafactor does not support sparse gradients.') raise RuntimeError('Adafactor does not support sparse gradients.')
state = self.state[p] state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape) factored, use_first_moment = self._get_options(group, grad.shape)
# State Initialization # State Initialization
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
@ -107,8 +108,8 @@ class Adafactor(torch.optim.Optimizer):
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(grad) state['exp_avg'] = torch.zeros_like(grad)
if factored: if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
else: else:
state['exp_avg_sq'] = torch.zeros_like(grad) state['exp_avg_sq'] = torch.zeros_like(grad)
@ -122,12 +123,12 @@ class Adafactor(torch.optim.Optimizer):
else: else:
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
p_data_fp32 = p.data p_fp32 = p
if p.data.dtype in {torch.float16, torch.bfloat16}: if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float() p_fp32 = p_fp32.float()
state['step'] += 1 state['step'] += 1
state['RMS'] = self._rms(p_data_fp32) state['RMS'] = self._rms(p_fp32)
lr_t = self._get_lr(group, state) lr_t = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@ -157,11 +158,10 @@ class Adafactor(torch.optim.Optimizer):
update = exp_avg update = exp_avg
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
p_data_fp32.add_(-update) p_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
if p.data.dtype in {torch.float16, torch.bfloat16}: p.copy_(p_fp32)
p.data.copy_(p_data_fp32)
return loss return loss

@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
wd = 1. wd = 1.
expand_size = (-1,) + (1,) * (len(p.shape) - 1) expand_size = (-1,) + (1,) * (len(p.shape) - 1)
for view_func in [_channel_view, _layer_view]: for view_func in [_channel_view, _layer_view]:
param_view = view_func(p.data) param_view = view_func(p)
grad_view = view_func(grad) grad_view = view_func(grad)
cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
# FIXME this is a problem for PyTorch XLA
if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
wd = wd_ratio wd = wd_ratio
return perturb, wd return perturb, wd
@ -47,17 +48,19 @@ class AdamP(Optimizer):
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
super(AdamP, self).__init__(params, defaults) super(AdamP, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
nesterov = group['nesterov'] nesterov = group['nesterov']
@ -66,8 +69,8 @@ class AdamP(Optimizer):
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p)
# Adam # Adam
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@ -94,9 +97,9 @@ class AdamP(Optimizer):
# Weight decay # Weight decay
if group['weight_decay'] > 0: if group['weight_decay'] > 0:
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
# Step # Step
p.data.add_(perturb, alpha=-step_size) p.add_(perturb, alpha=-step_size)
return loss return loss

@ -55,6 +55,7 @@ class AdamW(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('amsgrad', False) group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -64,7 +65,8 @@ class AdamW(Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
@ -75,7 +77,7 @@ class AdamW(Optimizer):
p.data.mul_(1 - group['lr'] * group['weight_decay']) p.data.mul_(1 - group['lr'] * group['weight_decay'])
# Perform optimization step # Perform optimization step
grad = p.grad.data grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad'] amsgrad = group['amsgrad']
@ -86,12 +88,12 @@ class AdamW(Optimizer):
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data) state['max_exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad: if amsgrad:
@ -115,6 +117,6 @@ class AdamW(Optimizer):
step_size = group['lr'] / bias_correction1 step_size = group['lr'] / bias_correction1
p.data.addcdiv_(exp_avg, denom, value=-step_size) p.addcdiv_(exp_avg, denom, value=-step_size)
return loss return loss

@ -5,10 +5,14 @@ This optimizer code was adapted from the following (starting with latest)
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb * https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason. similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below. Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
""" """
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. # Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
@ -60,8 +64,7 @@ class Lamb(Optimizer):
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments: Arguments:
params (iterable): iterable of parameters to optimize or dicts defining params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
parameter groups.
lr (float, optional): learning rate. (default: 1e-3) lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999)) running averages of gradient and its norm. (default: (0.9, 0.999))
@ -72,8 +75,7 @@ class Lamb(Optimizer):
calculating running averages of gradient. (default: True) calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad() set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True) method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
(default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False) weight decay parameter (default: False)
@ -91,25 +93,26 @@ class Lamb(Optimizer):
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb) grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
super().__init__(params, defaults) super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
closure (callable, optional): A closure that reevaluates the model closure (callable, optional): A closure that reevaluates the model
and returns the loss. and returns the loss.
""" """
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device) global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum()) global_grad_norm.add_(grad.pow(2).sum())
@ -145,15 +148,15 @@ class Lamb(Optimizer):
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data.div_(clip_global_grad_norm) grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@ -166,20 +169,21 @@ class Lamb(Optimizer):
weight_decay = group['weight_decay'] weight_decay = group['weight_decay']
if weight_decay != 0: if weight_decay != 0:
update.add_(p.data, alpha=weight_decay) update.add_(p, alpha=weight_decay)
trust_ratio = one_tensor
if weight_decay != 0 or group['use_nvlamb']: if weight_decay != 0 or group['use_nvlamb']:
# Layer adaptation. By default, skip layer adaptation on parameters that are # Layer adaptation. By default, skip layer adaptation on parameters that are
# excluded from weight decay, unless use_nvlamb == True, then always enabled. # excluded from weight decay, unless use_nvlamb == True, then always enabled.
w_norm = p.data.norm(2.0) w_norm = p.norm(2.0)
g_norm = update.norm(2.0) g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not working in PT XLA
trust_ratio = torch.where( trust_ratio = torch.where(
w_norm > 0, w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor), torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor, one_tensor,
) )
update.mul_(trust_ratio) update.mul_(trust_ratio)
p.data.add_(update, alpha=-group['lr'])
p.add_(update, alpha=-group['lr'])
return loss return loss

@ -0,0 +1,136 @@
""" PyTorch LARS / LARC Optimizer
An implementation of LARS (SGD) + LARC in PyTorch
Based on:
* PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
* NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
Additional cleanup and modifications to properly support PyTorch XLA.
Copyright 2021 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer, required
class Lars(Optimizer):
    """ LARS for PyTorch

    Layer-wise Adaptive Rate Scaling: SGD with momentum where each parameter
    tensor's effective lr is scaled by a per-layer trust ratio
    ``trust_coeff * ||w|| / (||g|| + ||w|| * weight_decay + eps)``.

    Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
        eps (float): eps for division denominator (default: 1e-8)
        larc (bool): enable LARC clipping (default: False)
        always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
    """

    def __init__(
            self,
            params,
            lr=required,
            momentum=0,
            dampening=0,
            weight_decay=0,
            nesterov=False,
            trust_coeff=0.001,
            eps=1e-8,
            larc=False,
            always_scale=False,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coeff=trust_coeff,
            eps=eps,
            larc=larc,
            always_scale=always_scale,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older checkpoints may lack the 'nesterov' key; default it off.
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]["params"][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            trust_coeff = group['trust_coeff']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # apply LARS scaling, LARC clipping, weight decay
                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
                # NOTE: by default params in groups with weight_decay == 0 are excluded from scaling.
                if weight_decay != 0 or group['always_scale']:
                    w_norm = p.norm(2.0)
                    g_norm = grad.norm(2.0)
                    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    # Fall back to trust_ratio = 1 when either norm is zero.
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, trust_ratio, one_tensor),
                        one_tensor,
                    )
                    if group['larc']:
                        # LARC: clip the adaptive lr so it never exceeds the group lr.
                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
                    # BUGFIX: was `grad.add(...)` (out-of-place, result discarded), which
                    # silently dropped weight decay. Must be the in-place `add_` so the
                    # decay term is folded into the gradient before scaling.
                    grad.add_(p, alpha=weight_decay)
                    grad.mul_(trust_ratio)

                # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: seed the buffer with the raw grad (no dampening), as in torch SGD.
                        buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(grad, alpha=1. - dampening)
                    if nesterov:
                        grad = grad.add(buf, alpha=momentum)
                    else:
                        grad = buf

                p.add_(grad, alpha=-group['lr'])

        return loss

@ -27,22 +27,24 @@ class Lookahead(Optimizer):
for group in self._base_optimizer.param_groups: for group in self._base_optimizer.param_groups:
group.setdefault(name, default) group.setdefault(name, default)
@torch.no_grad()
def update_slow(self, group): def update_slow(self, group):
for fast_p in group["params"]: for fast_p in group["params"]:
if fast_p.grad is None: if fast_p.grad is None:
continue continue
param_state = self._base_optimizer.state[fast_p] param_state = self._base_optimizer.state[fast_p]
if 'lookahead_slow_buff' not in param_state: if 'lookahead_slow_buff' not in param_state:
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data) param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
param_state['lookahead_slow_buff'].copy_(fast_p.data) param_state['lookahead_slow_buff'].copy_(fast_p)
slow = param_state['lookahead_slow_buff'] slow = param_state['lookahead_slow_buff']
slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha']) slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
fast_p.data.copy_(slow) fast_p.copy_(slow)
def sync_lookahead(self): def sync_lookahead(self):
for group in self._base_optimizer.param_groups: for group in self._base_optimizer.param_groups:
self.update_slow(group) self.update_slow(group)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
loss = self._base_optimizer.step(closure) loss = self._base_optimizer.step(closure)
for group in self._base_optimizer.param_groups: for group in self._base_optimizer.param_groups:

@ -82,6 +82,7 @@ class MADGRAD(torch.optim.Optimizer):
def supports_flat_params(self) -> bool: def supports_flat_params(self) -> bool:
return True return True
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step. """Performs a single optimization step.
@ -91,13 +92,10 @@ class MADGRAD(torch.optim.Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
# step counter must be stored in state to ensure correct behavior under step = self.state.setdefault('step', 0) # k
# optimizer sharding
if 'k' not in self.state:
self.state['k'] = torch.tensor([0], dtype=torch.long)
k = self.state['k'].item()
for group in self.param_groups: for group in self.param_groups:
eps = group["eps"] eps = group["eps"]
@ -106,19 +104,19 @@ class MADGRAD(torch.optim.Optimizer):
momentum = group["momentum"] momentum = group["momentum"]
ck = 1 - momentum ck = 1 - momentum
lamb = lr * math.pow(k + 1, 0.5) lamb = lr * math.sqrt(step + 1)
for p in group["params"]: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
state = self.state[p] state = self.state[p]
if "grad_sum_sq" not in state: if "grad_sum_sq" not in state:
state["grad_sum_sq"] = torch.zeros_like(p.data).detach() state["grad_sum_sq"] = torch.zeros_like(p)
state["s"] = torch.zeros_like(p.data).detach() state["s"] = torch.zeros_like(p)
if momentum != 0: if momentum != 0:
state["x0"] = torch.clone(p.data).detach() state["x0"] = torch.clone(p).detach()
if momentum != 0.0 and grad.is_sparse: if momentum != 0.0 and grad.is_sparse:
raise RuntimeError("momentum != 0 is not compatible with sparse gradients") raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
@ -129,11 +127,11 @@ class MADGRAD(torch.optim.Optimizer):
# Apply weight decay # Apply weight decay
if weight_decay != 0: if weight_decay != 0:
if group['decoupled_decay']: if group['decoupled_decay']:
p.data.mul_(1.0 - group['lr'] * weight_decay) p.mul_(1.0 - group['lr'] * weight_decay)
else: else:
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError("weight_decay option is not compatible with sparse gradients") raise RuntimeError("weight_decay option is not compatible with sparse gradients")
grad.add_(p.data, alpha=weight_decay) grad.add_(p, alpha=weight_decay)
if grad.is_sparse: if grad.is_sparse:
grad = grad.coalesce() grad = grad.coalesce()
@ -161,12 +159,12 @@ class MADGRAD(torch.optim.Optimizer):
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
# Copy updated masked p to dense p using an add operation # Copy updated masked p to dense p using an add operation
p_masked._values().add_(p_kp1_masked_vals, alpha=-1) p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
p.data.add_(p_masked, alpha=-1) p.add_(p_masked, alpha=-1)
else: else:
if momentum == 0: if momentum == 0:
# Compute x_0 from other known quantities # Compute x_0 from other known quantities
rms = grad_sum_sq.pow(1 / 3).add_(eps) rms = grad_sum_sq.pow(1 / 3).add_(eps)
x0 = p.data.addcdiv(s, rms, value=1) x0 = p.addcdiv(s, rms, value=1)
else: else:
x0 = state["x0"] x0 = state["x0"]
@ -175,16 +173,16 @@ class MADGRAD(torch.optim.Optimizer):
rms = grad_sum_sq.pow(1 / 3).add_(eps) rms = grad_sum_sq.pow(1 / 3).add_(eps)
# Update s # Update s
s.data.add_(grad, alpha=lamb) s.add_(grad, alpha=lamb)
# Step # Step
if momentum == 0: if momentum == 0:
p.data.copy_(x0.addcdiv(s, rms, value=-1)) p.copy_(x0.addcdiv(s, rms, value=-1))
else: else:
z = x0.addcdiv(s, rms, value=-1) z = x0.addcdiv(s, rms, value=-1)
# p is a moving average of z # p is a moving average of z
p.data.mul_(1 - ck).add_(z, alpha=ck) p.mul_(1 - ck).add_(z, alpha=ck)
self.state['k'] += 1 self.state['step'] += 1
return loss return loss

@ -1,3 +1,5 @@
import math
import torch import torch
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
@ -33,6 +35,7 @@ class Nadam(Optimizer):
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay) lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
super(Nadam, self).__init__(params, defaults) super(Nadam, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -42,21 +45,22 @@ class Nadam(Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
state['m_schedule'] = 1. state['m_schedule'] = 1.
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p)
# Warming momentum schedule # Warming momentum schedule
m_schedule = state['m_schedule'] m_schedule = state['m_schedule']
@ -66,9 +70,10 @@ class Nadam(Optimizer):
eps = group['eps'] eps = group['eps']
state['step'] += 1 state['step'] += 1
t = state['step'] t = state['step']
bias_correction2 = 1 - beta2 ** t
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
grad = grad.add(p.data, alpha=group['weight_decay']) grad = grad.add(p, alpha=group['weight_decay'])
momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay))) momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
@ -79,10 +84,9 @@ class Nadam(Optimizer):
# Decay the first and second moment running average coefficient # Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1) exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
denom = exp_avg_sq_prime.sqrt_().add_(eps)
p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new)) denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next)) p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
return loss return loss

@ -51,6 +51,7 @@ class NvNovoGrad(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('amsgrad', False) group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -60,13 +61,14 @@ class NvNovoGrad(Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.') raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad'] amsgrad = group['amsgrad']
@ -77,7 +79,7 @@ class NvNovoGrad(Optimizer):
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad: if amsgrad:
@ -108,11 +110,11 @@ class NvNovoGrad(Optimizer):
grad.div_(denom) grad.div_(denom)
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay']) grad.add_(p, alpha=group['weight_decay'])
if group['grad_averaging']: if group['grad_averaging']:
grad.mul_(1 - beta1) grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad) exp_avg.mul_(beta1).add_(grad)
p.data.add_(exp_avg, alpha=-group['lr']) p.add_(exp_avg, alpha=-group['lr'])
return loss return loss

@ -12,6 +12,7 @@ from .adafactor import Adafactor
from .adahessian import Adahessian from .adahessian import Adahessian
from .adamp import AdamP from .adamp import AdamP
from .lamb import Lamb from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead from .lookahead import Lookahead
from .madgrad import MADGRAD from .madgrad import MADGRAD
from .nadam import Nadam from .nadam import Nadam
@ -163,6 +164,14 @@ def create_optimizer_v2(
optimizer = Adafactor(parameters, **opt_args) optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'lamb': elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args) optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'larc':
optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
elif opt_lower == 'lars':
optimizer = Lars(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'nlarc':
optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
elif opt_lower == 'nlars':
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'madgrad': elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args) optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'madgradw': elif opt_lower == 'madgradw':

@ -18,31 +18,33 @@ class RAdam(Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super(RAdam, self).__setstate__(state) super(RAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data.float() grad = p.grad.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients') raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float() p_fp32 = p.float()
state = self.state[p] state = self.state[p]
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32) state['exp_avg'] = torch.zeros_like(p_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) state['exp_avg_sq'] = torch.zeros_like(p_fp32)
else: else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
@ -73,15 +75,15 @@ class RAdam(Optimizer):
buffered[2] = step_size buffered[2] = step_size
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
# more conservative since it's an approximated value # more conservative since it's an approximated value
if num_sma >= 5: if num_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps']) denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else: else:
p_data_fp32.add_(exp_avg, alpha=-step_size) p_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32) p.copy_(p_fp32)
return loss return loss

@ -4,7 +4,7 @@ Originally cut & paste from PyTorch RMSProp
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
Modifications Copyright 2020 Ross Wightman Modifications Copyright 2021 Ross Wightman
""" """
import torch import torch
@ -69,6 +69,7 @@ class RMSpropTF(Optimizer):
group.setdefault('momentum', 0) group.setdefault('momentum', 0)
group.setdefault('centered', False) group.setdefault('centered', False)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -78,13 +79,14 @@ class RMSpropTF(Optimizer):
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('RMSprop does not support sparse gradients') raise RuntimeError('RMSprop does not support sparse gradients')
state = self.state[p] state = self.state[p]
@ -92,11 +94,11 @@ class RMSpropTF(Optimizer):
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero
if group['momentum'] > 0: if group['momentum'] > 0:
state['momentum_buffer'] = torch.zeros_like(p.data) state['momentum_buffer'] = torch.zeros_like(p)
if group['centered']: if group['centered']:
state['grad_avg'] = torch.zeros_like(p.data) state['grad_avg'] = torch.zeros_like(p)
square_avg = state['square_avg'] square_avg = state['square_avg']
one_minus_alpha = 1. - group['alpha'] one_minus_alpha = 1. - group['alpha']
@ -105,9 +107,9 @@ class RMSpropTF(Optimizer):
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
if group['decoupled_decay']: if group['decoupled_decay']:
p.data.mul_(1. - group['lr'] * group['weight_decay']) p.mul_(1. - group['lr'] * group['weight_decay'])
else: else:
grad = grad.add(p.data, alpha=group['weight_decay']) grad = grad.add(p, alpha=group['weight_decay'])
# Tensorflow order of ops for updating squared avg # Tensorflow order of ops for updating squared avg
square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
@ -126,12 +128,12 @@ class RMSpropTF(Optimizer):
# Tensorflow accumulates the LR scaling in the momentum buffer # Tensorflow accumulates the LR scaling in the momentum buffer
if group['lr_in_momentum']: if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
p.data.add_(-buf) p.add_(-buf)
else: else:
# PyTorch scales the param update by LR # PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg) buf.mul_(group['momentum']).addcdiv_(grad, avg)
p.data.add_(buf, alpha=-group['lr']) p.add_(buf, alpha=-group['lr'])
else: else:
p.data.addcdiv_(grad, avg, value=-group['lr']) p.addcdiv_(grad, avg, value=-group['lr'])
return loss return loss

@ -24,10 +24,12 @@ class SGDP(Optimizer):
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
super(SGDP, self).__init__(params, defaults) super(SGDP, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() with torch.enable_grad():
loss = closure()
for group in self.param_groups: for group in self.param_groups:
weight_decay = group['weight_decay'] weight_decay = group['weight_decay']
@ -38,12 +40,12 @@ class SGDP(Optimizer):
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data grad = p.grad
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
state['momentum'] = torch.zeros_like(p.data) state['momentum'] = torch.zeros_like(p)
# SGD # SGD
buf = state['momentum'] buf = state['momentum']
@ -60,9 +62,9 @@ class SGDP(Optimizer):
# Weight decay # Weight decay
if weight_decay != 0: if weight_decay != 0:
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
# Step # Step
p.data.add_(d_p, alpha=-group['lr']) p.add_(d_p, alpha=-group['lr'])
return loss return loss

Loading…
Cancel
Save