One more scalar -> tensor fix for lamb optimizer

pull/816/head
Ross Wightman 3 years ago
parent 8f68193c91
commit 9541f4963b

@ -98,7 +98,7 @@ class Lamb(Optimizer):
and returns the loss.
"""
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device)
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
loss = None
if closure is not None:
@ -115,7 +115,9 @@ class Lamb(Optimizer):
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
max_grad_norm = self.defaults['max_grad_norm']
# FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
clip_global_grad_norm = torch.where(
global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,

Loading…
Cancel
Save