From 9541f4963bb6188789b65b2902373ec50ccfdc8e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 18 Aug 2021 11:20:25 -0700
Subject: [PATCH] One more scalar -> tensor fix for lamb optimizer

---
 timm/optim/lamb.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index 19cfd121..5308e348 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -98,7 +98,7 @@ class Lamb(Optimizer):
                 and returns the loss.
         """
         device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
         loss = None
         if closure is not None:
@@ -115,7 +115,9 @@
                 global_grad_norm.add_(grad.pow(2).sum())

         global_grad_norm = torch.sqrt(global_grad_norm)
-        max_grad_norm = self.defaults['max_grad_norm']
+        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
         clip_global_grad_norm = torch.where(
             global_grad_norm > max_grad_norm,
             global_grad_norm / max_grad_norm,