From 20d66beead659c82d4bd7358ae3929ad74e5a0d8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 14 May 2019 18:30:46 -0700
Subject: [PATCH] Move RMSpropTF another step closer to Tensorflow impl

* init square_avg with one instead of zero as per TF
* match TF order of ops for square_avg accumulation
* move LR scaling to momentum buffer accumulator as per TF
* add decoupled weight decay flag (not in TF)
---
 optim/rmsprop_tf.py | 43 ++++++++++++++++++++++++++++++-------------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/optim/rmsprop_tf.py b/optim/rmsprop_tf.py
index b88238ea..b4298734 100644
--- a/optim/rmsprop_tf.py
+++ b/optim/rmsprop_tf.py
@@ -19,16 +19,20 @@ class RMSpropTF(Optimizer):
             parameter groups
         lr (float, optional): learning rate (default: 1e-2)
         momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing constant (default: 0.99)
+        alpha (float, optional): smoothing (decay) constant (default: 0.9)
         eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
+            numerical stability (default: 1e-10)
         centered (bool, optional) : if ``True``, compute the centered RMSProp,
             the gradient is normalized by an estimation of its variance
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
+            update as per defaults in Tensorflow

     """

-    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
+                 decoupled_decay=False, lr_in_momentum=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -40,7 +44,8 @@ class RMSpropTF(Optimizer):
         if not 0.0 <= alpha:
             raise ValueError("Invalid alpha value: {}".format(alpha))

-        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
+                        decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
         super(RMSpropTF, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -72,33 +77,45 @@ class RMSpropTF(Optimizer):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
                     if group['momentum'] > 0:
                         state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
                         state['grad_avg'] = torch.zeros_like(p.data)

                 square_avg = state['square_avg']
-                alpha = group['alpha']
+                one_minus_alpha = 1. - group['alpha']

                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
+                    if group['decoupled_decay']:
+                        p.data.add_(-group['weight_decay'], p.data)
+                    else:
+                        grad = grad.add(group['weight_decay'], p.data)

-                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+                # Tensorflow order of ops for updating squared avg
+                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
+                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

                 if group['centered']:
                     grad_avg = state['grad_avg']
-                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
-                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
+                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
+                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
+                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()  # eps moved in sqrt
                 else:
-                    avg = square_avg.add(group['eps']).sqrt_()
+                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                    p.data.add_(-group['lr'], buf)
+                    # Tensorflow accumulates the LR scaling in the momentum buffer
+                    if group['lr_in_momentum']:
+                        buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
+                        p.data.add_(-buf)
+                    else:
+                        # PyTorch scales the param update by LR
+                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        p.data.add_(-group['lr'], buf)
                 else:
                     p.data.addcdiv_(-group['lr'], grad, avg)
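
Usage note (not part of the patch): a minimal sketch of constructing the updated
optimizer, assuming this module is importable as optim.rmsprop_tf and using a
placeholder torch.nn.Linear model purely for illustration. With the default
lr_in_momentum=True the LR is folded into the momentum buffer as in Tensorflow;
passing lr_in_momentum=False restores the PyTorch-style scaling of the update by LR.
Note that the two square_avg formulations are algebraically equivalent
(s += (1 - alpha) * (g**2 - s) equals s = alpha * s + (1 - alpha) * g**2);
only the order of operations and rounding differ.

    import torch
    import torch.nn as nn
    from optim.rmsprop_tf import RMSpropTF  # assumes repo layout above

    model = nn.Linear(10, 2)  # placeholder model, illustration only

    # TF-matching defaults: alpha=0.9, eps=1e-10, square_avg initialized to ones,
    # LR scaling accumulated inside the momentum buffer (lr_in_momentum=True)
    optimizer = RMSpropTF(model.parameters(), lr=1e-2, momentum=0.9,
                          weight_decay=1e-5, decoupled_decay=True)

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()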