@@ -15,9 +15,11 @@ except ImportError:
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        if clip_grad:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
         optimizer.step()
 
     def state_dict(self):
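For reference, a minimal usage sketch of the Apex path above; everything except ApexScaler, amp, and the new clip_grad argument is an illustrative assumption (NVIDIA apex must be installed and the model/optimizer wrapped via amp.initialize beforehand):

# Hedged sketch, not part of the diff: model, optimizer, and data are placeholders.
import torch
from apex import amp

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

scaler = ApexScaler()  # the class patched above
loss = model(torch.randn(8, 10, device='cuda')).mean()
# The clip value is applied to amp.master_params(optimizer), so parameters= is not needed on this path.
scaler(loss, optimizer, clip_grad=5.0)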
@@ -35,8 +37,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         self._scaler.scale(loss).backward()
+        if clip_grad:
+            assert parameters is not None
+            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
         self._scaler.step(optimizer)
         self._scaler.update()
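A corresponding hedged sketch for the native AMP path; the toy model, data, and clip value are assumptions for illustration, only the scaler call signature comes from the diff:

# Hedged sketch, not part of the diff: the model, optimizer, and data below are placeholders.
import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = NativeScaler()  # the class patched above

inputs = torch.randn(8, 10, device='cuda')
targets = torch.randint(0, 2, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

# With clip_grad set, the scaler unscales the optimizer's gradients in place and
# clips them to the given max norm before calling optimizer.step().
scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())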