|
|
|
@ -98,7 +98,7 @@ class Lamb(Optimizer):
|
|
|
|
|
and returns the loss.
|
|
|
|
|
"""
|
|
|
|
|
device = self.param_groups[0]["params"][0].device
|
|
|
|
|
one_tensor = torch.tensor(1.0, device=device)
|
|
|
|
|
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
|
|
|
|
|
|
|
|
|
|
loss = None
|
|
|
|
|
if closure is not None:
|
|
|
|
@ -115,7 +115,9 @@ class Lamb(Optimizer):
|
|
|
|
|
global_grad_norm.add_(grad.pow(2).sum())
|
|
|
|
|
|
|
|
|
|
global_grad_norm = torch.sqrt(global_grad_norm)
|
|
|
|
|
max_grad_norm = self.defaults['max_grad_norm']
|
|
|
|
|
# FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
|
|
|
|
|
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
|
|
|
|
|
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
|
|
|
|
|
clip_global_grad_norm = torch.where(
|
|
|
|
|
global_grad_norm > max_grad_norm,
|
|
|
|
|
global_grad_norm / max_grad_norm,
|
|
|
|
|