diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index 19cfd121..5308e348 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -98,7 +98,7 @@ class Lamb(Optimizer):
                 and returns the loss.
         """
         device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

         loss = None
         if closure is not None:
@@ -115,7 +115,9 @@ class Lamb(Optimizer):
                 global_grad_norm.add_(grad.pow(2).sum())

         global_grad_norm = torch.sqrt(global_grad_norm)
-        max_grad_norm = self.defaults['max_grad_norm']
+        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
         clip_global_grad_norm = torch.where(
             global_grad_norm > max_grad_norm,
             global_grad_norm / max_grad_norm,
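
A minimal standalone sketch (not part of the diff; the values and device choice are made up for illustration) of why the scalar is converted: torch.where expects its value arguments as tensors on matching dtype/device, and on the PyTorch versions referenced by the FIXME it did not promote plain Python scalars properly (pytorch/pytorch#9190), so both the 1.0 fallback and max_grad_norm are materialized as tensors on the parameters' device before the comparison.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretend this came from accumulating grad.pow(2).sum() over all parameters.
global_grad_norm = torch.tensor(3.0, device=device)

# Explicit tensor conversions, mirroring what the diff does for one_tensor and max_grad_norm.
one_tensor = torch.tensor(1.0, device=device)
max_grad_norm = torch.tensor(1.0, device=device)

# Clip factor: scale down only when the global norm exceeds the threshold,
# otherwise leave gradients untouched (factor of 1.0).
clip_global_grad_norm = torch.where(
    global_grad_norm > max_grad_norm,
    global_grad_norm / max_grad_norm,
    one_tensor)

print(clip_global_grad_norm)  # tensor(3.) here, i.e. gradients would be divided by 3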