From fcb6258877536910df4961f8193c444ba931fc65 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 2 Oct 2020 16:17:42 -0700
Subject: [PATCH] Add missing leaky_relu layer factory defn, update
 Apex/Native loss scaler interfaces to support unscaled grad clipping. Bump
 ver to 0.2.2 for pending release.

---
 timm/models/layers/create_act.py |  1 +
 timm/utils/cuda.py               | 10 ++++++++--
 timm/version.py                  |  2 +-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index a0eec2ad..5bc4db99 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -46,6 +46,7 @@ _ACT_LAYER_DEFAULT = dict(
     mish=Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
+    leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
     prelu=nn.PReLU,
     celu=nn.CELU,
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index 695f40b1..d972002c 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -15,9 +15,11 @@ except ImportError:
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        if clip_grad:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
         optimizer.step()
 
     def state_dict(self):
@@ -35,8 +37,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         self._scaler.scale(loss).backward()
+        if clip_grad:
+            assert parameters is not None
+            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
         self._scaler.step(optimizer)
         self._scaler.update()
 
diff --git a/timm/version.py b/timm/version.py
index fc79d63d..020ed73d 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.2.1'
+__version__ = '0.2.2'
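
Usage note (not part of the patch): below is a minimal sketch of how the updated interfaces might be exercised from a training step. It assumes a CUDA-capable PyTorch 1.6+ environment; the model, optimizer, and tensors are throwaway placeholders for illustration, not code from timm's train.py.

import torch
import torch.nn as nn
from timm.models.layers.create_act import get_act_layer
from timm.utils.cuda import NativeScaler

# The new 'leaky_relu' entry resolves through the activation layer factory.
act_layer = get_act_layer('leaky_relu')
model = nn.Sequential(nn.Linear(16, 16), act_layer(inplace=True), nn.Linear(16, 4)).cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

inputs = torch.randn(8, 16, device='cuda')
targets = torch.randint(0, 4, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(inputs), targets)

# The scaler call now accepts clip_grad/parameters; gradients are unscaled
# before clipping, then the optimizer step and scaler update are applied.
loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())

With ApexScaler the same call shape works, but per the patch the clipping there is applied to amp.master_params(optimizer), so the parameters argument is not used by that path.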