You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
5.1 KiB
135 lines
5.1 KiB
""" PyTorch LARS / LARC Optimizer
|
|
|
|
An implementation of LARS (SGD) + LARC in PyTorch
|
|
|
|
Based on:
|
|
* PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
|
|
* NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
|
|
|
|
Additional cleanup and modifications to properly support PyTorch XLA.
|
|
|
|
Copyright 2021 Ross Wightman
|
|
"""
|
|
import torch
|
|
from torch.optim.optimizer import Optimizer
|
|
|
|
|
|
class Lars(Optimizer):
|
|
""" LARS for PyTorch
|
|
|
|
Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
|
|
|
|
Args:
|
|
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
|
|
lr (float, optional): learning rate (default: 1.0).
|
|
momentum (float, optional): momentum factor (default: 0)
|
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
|
dampening (float, optional): dampening for momentum (default: 0)
|
|
nesterov (bool, optional): enables Nesterov momentum (default: False)
|
|
trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
|
|
eps (float): eps for division denominator (default: 1e-8)
|
|
trust_clip (bool): enable LARC trust ratio clipping (default: False)
|
|
always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
params,
|
|
lr=1.0,
|
|
momentum=0,
|
|
dampening=0,
|
|
weight_decay=0,
|
|
nesterov=False,
|
|
trust_coeff=0.001,
|
|
eps=1e-8,
|
|
trust_clip=False,
|
|
always_adapt=False,
|
|
):
|
|
if lr < 0.0:
|
|
raise ValueError(f"Invalid learning rate: {lr}")
|
|
if momentum < 0.0:
|
|
raise ValueError(f"Invalid momentum value: {momentum}")
|
|
if weight_decay < 0.0:
|
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
|
if nesterov and (momentum <= 0 or dampening != 0):
|
|
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
|
|
|
|
defaults = dict(
|
|
lr=lr,
|
|
momentum=momentum,
|
|
dampening=dampening,
|
|
weight_decay=weight_decay,
|
|
nesterov=nesterov,
|
|
trust_coeff=trust_coeff,
|
|
eps=eps,
|
|
trust_clip=trust_clip,
|
|
always_adapt=always_adapt,
|
|
)
|
|
super().__init__(params, defaults)
|
|
|
|
def __setstate__(self, state):
|
|
super().__setstate__(state)
|
|
for group in self.param_groups:
|
|
group.setdefault("nesterov", False)
|
|
|
|
@torch.no_grad()
|
|
def step(self, closure=None):
|
|
"""Performs a single optimization step.
|
|
|
|
Args:
|
|
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
|
"""
|
|
loss = None
|
|
if closure is not None:
|
|
with torch.enable_grad():
|
|
loss = closure()
|
|
|
|
device = self.param_groups[0]['params'][0].device
|
|
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
|
|
|
|
for group in self.param_groups:
|
|
weight_decay = group['weight_decay']
|
|
momentum = group['momentum']
|
|
dampening = group['dampening']
|
|
nesterov = group['nesterov']
|
|
trust_coeff = group['trust_coeff']
|
|
eps = group['eps']
|
|
|
|
for p in group['params']:
|
|
if p.grad is None:
|
|
continue
|
|
grad = p.grad
|
|
|
|
# apply LARS LR adaptation, LARC clipping, weight decay
|
|
# ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
|
|
if weight_decay != 0 or group['always_adapt']:
|
|
w_norm = p.norm(2.0)
|
|
g_norm = grad.norm(2.0)
|
|
trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
|
|
# FIXME nested where required since logical and/or not working in PT XLA
|
|
trust_ratio = torch.where(
|
|
w_norm > 0,
|
|
torch.where(g_norm > 0, trust_ratio, one_tensor),
|
|
one_tensor,
|
|
)
|
|
if group['trust_clip']:
|
|
trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
|
|
grad.add(p, alpha=weight_decay)
|
|
grad.mul_(trust_ratio)
|
|
|
|
# apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
|
|
if momentum != 0:
|
|
param_state = self.state[p]
|
|
if 'momentum_buffer' not in param_state:
|
|
buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
|
|
else:
|
|
buf = param_state['momentum_buffer']
|
|
buf.mul_(momentum).add_(grad, alpha=1. - dampening)
|
|
if nesterov:
|
|
grad = grad.add(buf, alpha=momentum)
|
|
else:
|
|
grad = buf
|
|
|
|
p.add_(grad, alpha=-group['lr'])
|
|
|
|
return loss |