Add missing leaky_relu layer factory defn, update Apex/Native loss scaler interfaces to support unscaled grad clipping. Bump ver to 0.2.2 for pending release.

pull/256/head
Ross Wightman 4 years ago
parent 186075ef03
commit fcb6258877

@@ -46,6 +46,7 @@ _ACT_LAYER_DEFAULT = dict(
     mish=Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
+    leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
     prelu=nn.PReLU,
     celu=nn.CELU,
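
For context, a minimal usage sketch (not part of the diff) showing how the new entry is resolved, assuming this dict backs timm's get_act_layer lookup in the activation factory:

    # Hedged sketch: the import path assumes timm's layer factory exports get_act_layer.
    import torch
    from timm.models.layers import get_act_layer

    act_layer = get_act_layer('leaky_relu')           # resolves to nn.LeakyReLU via _ACT_LAYER_DEFAULT
    act = act_layer(negative_slope=0.1, inplace=True)
    print(act(torch.tensor([-1.0, 2.0])))             # tensor([-0.1000,  2.0000])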

@@ -15,9 +15,11 @@ except ImportError:
 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        if clip_grad:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
         optimizer.step()

     def state_dict(self):
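
A hedged sketch (not from the commit) of how a training step would drive the updated ApexScaler interface; the model/optimizer/data setup is placeholder code, only the scaler call reflects the new signature:

    import torch
    import torch.nn as nn
    from apex import amp  # requires NVIDIA Apex and a CUDA device

    model = nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    loss_scaler = ApexScaler()
    inputs, targets = torch.randn(8, 10).cuda(), torch.randint(0, 2, (8,)).cuda()
    loss = nn.functional.cross_entropy(model(inputs), targets)
    # clip_grad clips amp.master_params(optimizer) after the scaled backward, before step()
    loss_scaler(loss, optimizer, clip_grad=5.0)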
@@ -35,8 +37,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         self._scaler.scale(loss).backward()
+        if clip_grad:
+            assert parameters is not None
+            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
         self._scaler.step(optimizer)
         self._scaler.update()
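
Likewise, a hedged sketch of the native AMP path; here parameters must be passed so the scaler can unscale and clip the actual gradients before the step. The surrounding setup is placeholder code:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_scaler = NativeScaler()

    inputs, targets = torch.randn(8, 10).cuda(), torch.randint(0, 2, (8,)).cuda()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(inputs), targets)
    # gradients are unscaled in-place, then clipped to a max norm of 5.0
    loss_scaler(loss, optimizer, clip_grad=5.0, parameters=model.parameters())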

@@ -1 +1 @@
-__version__ = '0.2.1'
+__version__ = '0.2.2'
