@@ -19,16 +19,20 @@ class RMSpropTF(Optimizer):
             parameter groups
         lr (float, optional): learning rate (default: 1e-2)
         momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing constant (default: 0.99)
+        alpha (float, optional): smoothing (decay) constant (default: 0.9)
         eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
+            numerical stability (default: 1e-10)
         centered (bool, optional): if ``True``, compute the centered RMSProp,
             the gradient is normalized by an estimation of its variance
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
+            update as per defaults in Tensorflow
     """
-    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
+                 decoupled_decay=False, lr_in_momentum=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -40,7 +44,8 @@ class RMSpropTF(Optimizer):
         if not 0.0 <= alpha:
             raise ValueError("Invalid alpha value: {}".format(alpha))
 
-        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
+                        decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
         super(RMSpropTF, self).__init__(params, defaults)
 
     def __setstate__(self, state):
@@ -72,33 +77,45 @@ class RMSpropTF(Optimizer):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
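+                    # NOTE: with a ones init the denominator starts near 1, so early steps are
+                    # roughly lr * grad; a zeros init would scale the very first step up by
+                    # about 1 / sqrt(1 - alpha)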
                     if group['momentum'] > 0:
                         state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
                         state['grad_avg'] = torch.zeros_like(p.data)
 
                 square_avg = state['square_avg']
-                alpha = group['alpha']
+                one_minus_alpha = 1. - group['alpha']
 
                 state['step'] += 1
 
                 if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
+                    if group['decoupled_decay']:
+                        p.data.add_(-group['weight_decay'], p.data)
+                    else:
+                        grad = grad.add(group['weight_decay'], p.data)
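+                # the decoupled branch shrinks the weights directly, p <- (1 - wd) * p,
+                # instead of folding wd * p into the gradient as an L2 penalty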
-                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+                # Tensorflow order of ops for updating squared avg
+                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
+                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original
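+                # both forms evaluate to alpha * square_avg + (1 - alpha) * grad^2; only the
+                # order of operations (and thus floating-point rounding) differs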
                 if group['centered']:
                     grad_avg = state['grad_avg']
-                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
-                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
+                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
+                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
+                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()  # eps moved in sqrt
                 else:
-                    avg = square_avg.add(group['eps']).sqrt_()
+                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt
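+                # sqrt(square_avg + eps), as in TF, stays well away from zero when square_avg
+                # is tiny, whereas PyTorch's sqrt(square_avg) + eps can leave a much smaller
+                # denominator for eps < 1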
                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                    p.data.add_(-group['lr'], buf)
+                    # Tensorflow accumulates the LR scaling in the momentum buffer
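+                    # i.e. buf = momentum * buf + lr * grad / avg, then p -= buf; this matches
+                    # the PyTorch form below while LR is constant, but the two diverge once an
+                    # LR schedule changes group['lr'] between steps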
+                    if group['lr_in_momentum']:
+                        buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
+                        p.data.add_(-buf)
+                    else:
+                        # PyTorch scales the param update by LR
+                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        p.data.add_(-group['lr'], buf)
                 else:
                     p.data.addcdiv_(-group['lr'], grad, avg)