Move RMSpropTF another step closer to Tensorflow impl

* init square_avg with one instead of zero as per TF
* match TF order of ops for square_avg accumulation
* move LR scaling to momentum buffer accumulator as per TF
* add decoupled weight decay flag (not in TF)
pull/2/head
Ross Wightman 6 years ago
parent 89147a91e6
commit 20d66beead

@ -19,16 +19,20 @@ class RMSpropTF(Optimizer):
parameter groups
lr (float, optional): learning rate (default: 1e-2)
momentum (float, optional): momentum factor (default: 0)
alpha (float, optional): smoothing constant (default: 0.99)
alpha (float, optional): smoothing (decay) constant (default: 0.9)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
numerical stability (default: 1e-10)
centered (bool, optional) : if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
update as per defaults in Tensorflow
"""
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
decoupled_decay=False, lr_in_momentum=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@ -40,7 +44,8 @@ class RMSpropTF(Optimizer):
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
super(RMSpropTF, self).__init__(params, defaults)
def __setstate__(self, state):
@ -72,31 +77,43 @@ class RMSpropTF(Optimizer):
# State initialization
if len(state) == 0:
state['step'] = 0
state['square_avg'] = torch.zeros_like(p.data)
state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero
if group['momentum'] > 0:
state['momentum_buffer'] = torch.zeros_like(p.data)
if group['centered']:
state['grad_avg'] = torch.zeros_like(p.data)
square_avg = state['square_avg']
alpha = group['alpha']
one_minus_alpha = 1. - group['alpha']
state['step'] += 1
if group['weight_decay'] != 0:
if group['decoupled_decay']:
p.data.add_(-group['weight_decay'], p.data)
else:
grad = grad.add(group['weight_decay'], p.data)
square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
# Tensorflow order of ops for updating squared avg
square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
# square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
if group['centered']:
grad_avg = state['grad_avg']
grad_avg.mul_(alpha).add_(1 - alpha, grad)
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
grad_avg.add_(one_minus_alpha, grad - grad_avg)
# grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt
else:
avg = square_avg.add(group['eps']).sqrt_()
avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
if group['momentum'] > 0:
buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer
if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
p.data.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg)
p.data.add_(-group['lr'], buf)
else:

Loading…
Cancel
Save