parent
5f9aff395c
commit
4f49b94311
@ -0,0 +1,42 @@
|
||||
""" Adaptive Gradient Clipping
|
||||
|
||||
An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
|
||||
|
||||
@article{brock2021high,
|
||||
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
|
||||
title={High-Performance Large-Scale Image Recognition Without Normalization},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
Code references:
|
||||
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
|
||||
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def unitwise_norm(x, norm_type=2.0):
|
||||
if x.ndim <= 1:
|
||||
return x.norm(norm_type)
|
||||
else:
|
||||
# works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor
|
||||
# might need special cases for other weights (possibly MHA) where this may not be true
|
||||
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
|
||||
|
||||
|
||||
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
p_data = p.detach()
|
||||
g_data = p.grad.detach()
|
||||
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
|
||||
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
|
||||
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
|
||||
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
|
||||
p.grad.detach().copy_(new_grads)
|
@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
from timm.utils.agc import adaptive_clip_grad
|
||||
|
||||
|
||||
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
|
||||
""" Dispatch to gradient clipping method
|
||||
|
||||
Args:
|
||||
parameters (Iterable): model parameters to clip
|
||||
value (float): clipping value/factor/norm, mode dependant
|
||||
mode (str): clipping mode, one of 'norm', 'value', 'agc'
|
||||
norm_type (float): p-norm, default 2.0
|
||||
"""
|
||||
if mode == 'norm':
|
||||
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
|
||||
elif mode == 'value':
|
||||
torch.nn.utils.clip_grad_value_(parameters, value)
|
||||
elif mode == 'agc':
|
||||
adaptive_clip_grad(parameters, value, norm_type=norm_type)
|
||||
else:
|
||||
assert False, f"Unknown clip mode ({mode})."
|
||||
|
Loading…
Reference in new issue