Add Nvidia's NovoGrad impl from Jasper (cleaner/faster than current) and Apex Fused optimizers

pull/32/head
Ross Wightman 5 years ago
parent 3d9c8a6489
commit 64966f61f7

timm/optim/__init__.py
@@ -3,5 +3,6 @@ from .rmsprop_tf import RMSpropTF
 from .adamw import AdamW
 from .radam import RAdam
 from .novograd import NovoGrad
+from .nvnovograd import NvNovoGrad
 from .lookahead import Lookahead
 from .optim_factory import create_optimizer

timm/optim/nvnovograd.py (new file)
@@ -0,0 +1,118 @@
""" Nvidia NovoGrad Optimizer.
Original impl by Nvidia from Jasper example:
    - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
    - https://arxiv.org/abs/1905.11286
"""
import torch
from torch.optim.optimizer import Optimizer
import math


class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)
        super(NvNovoGrad, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(NvNovoGrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(-group['lr'], exp_avg)

        return loss
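For reference, a minimal usage sketch of the new optimizer (not part of the commit; the toy model, data, and hyperparameters are illustrative, and it assumes a PyTorch version contemporary with this code, since the in-place add_(scalar, tensor) form used above is deprecated in later releases):

import torch
import torch.nn as nn
from timm.optim import NvNovoGrad  # re-exported via timm/optim/__init__.py above

# Toy model and data purely for illustration
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# Defaults mirror the Jasper impl: betas=(0.95, 0.98), eps=1e-8
optimizer = NvNovoGrad(model.parameters(), lr=1e-2, weight_decay=1e-3, grad_averaging=True)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

Note that, unlike Adam, the second-moment state here is a single scalar per parameter tensor (a running average of the squared gradient norm, see exp_avg_sq above), which is what makes the method layer-wise and keeps its memory overhead close to plain SGD with momentum.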

timm/optim/optim_factory.py
@@ -1,5 +1,11 @@
+import torch
 from torch import optim as optim
-from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead
+from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
+
+try:
+    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+    has_apex = True
+except ImportError:
+    has_apex = False

 def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
@@ -20,9 +26,10 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
 def create_optimizer(args, model, filter_bias_and_bn=True):
     opt_lower = args.opt.lower()
     weight_decay = args.weight_decay
-    if opt_lower == 'adamw' or opt_lower == 'radam':
-        # compensate for the way current AdamW and RAdam optimizers
-        # apply the weight-decay
+    if 'adamw' in opt_lower or 'radam' in opt_lower:
+        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
+        # I don't believe they follow the paper or original Torch7 impl which schedules weight
+        # decay based on the ratio of current_lr/initial_lr
         weight_decay /= args.lr
     if weight_decay and filter_bias_and_bn:
         parameters = add_weight_decay(model, weight_decay)
@@ -30,12 +37,14 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     else:
         parameters = model.parameters()

+    if 'fused' in opt_lower:
+        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd':
         optimizer = optim.SGD(
-            parameters, lr=args.lr,
-            momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
     elif opt_lower == 'adam':
         optimizer = optim.Adam(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
@@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
             momentum=args.momentum, weight_decay=weight_decay)
     elif opt_lower == 'novograd':
         optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'nvnovograd':
+        optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedsgd':
+        optimizer = FusedSGD(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+    elif opt_lower == 'fusedadam':
+        optimizer = FusedAdam(
+            parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedadamw':
+        optimizer = FusedAdam(
+            parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedlamb':
+        optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusednovograd':
+        optimizer = FusedNovoGrad(
+            parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
     else:
         assert False and "Invalid optimizer"
         raise ValueError
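A rough sketch of how the factory would be driven after this change (not part of the commit; the SimpleNamespace stand-in is illustrative and only sets the args fields referenced in this diff, whereas the real training script passes its full argparse namespace):

from types import SimpleNamespace
import torch.nn as nn
from timm.optim import create_optimizer

model = nn.Linear(10, 2)  # stand-in for a timm model

# Fields read by create_optimizer in this version: opt, lr, weight_decay, momentum, opt_eps
args = SimpleNamespace(opt='nvnovograd', lr=0.01, weight_decay=1e-5, momentum=0.9, opt_eps=1e-8)
optimizer = create_optimizer(args, model)

# The fused variants dispatch on the 'fused' prefix, e.g. opt='fusedadamw' maps to
# apex.optimizers.FusedAdam with adam_w_mode=True, and assert that Apex and CUDA are available.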
