Merge pull request #437 from rwightman/agc
Adaptive Gradient Clipping (AGC) Impl
pull/440/head v0.1-dnf-weights
commit
4ea5931964
@@ -0,0 +1,42 @@
""" Adaptive Gradient Clipping

An impl of AGC, as per (https://arxiv.org/abs/2102.06171):

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Code references:
  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch


def unitwise_norm(x, norm_type=2.0):
    if x.ndim <= 1:
        return x.norm(norm_type)
    else:
        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
        # might need special cases for other weights (possibly MHA) where this may not be true
        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
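As a quick illustration (not part of the diff itself, shapes chosen arbitrarily), unitwise_norm keeps one norm per output unit for weight tensors but collapses 1-D tensors such as biases to a single scalar:

# Illustrative only: a 2D-conv weight (out_ch, in_ch, kH, kW) and a bias vector.
w = torch.randn(64, 32, 3, 3)
b = torch.randn(64)
print(unitwise_norm(w).shape)  # torch.Size([64, 1, 1, 1]) -- one norm per output channel
print(unitwise_norm(b).shape)  # torch.Size([]) -- a single scalar norm
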
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for p in parameters:
        if p.grad is None:
            continue
        p_data = p.detach()
        g_data = p.grad.detach()
        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
        p.grad.detach().copy_(new_grads)
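For context, a minimal sketch (not code from this PR; the toy model, data, and optimizer are made up): AGC rescales each unit's gradient so its norm stays below clip_factor times the corresponding weight norm (floored at eps), and would typically be called between backward() and the optimizer step:

import torch
import torch.nn as nn

# Toy model/data purely to illustrate the call site.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 16), torch.randint(0, 4, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
adaptive_clip_grad(model.parameters(), clip_factor=0.01)  # unit-wise clip relative to weight norms
optimizer.step()
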
@@ -0,0 +1,23 @@
import torch

from timm.utils.agc import adaptive_clip_grad


def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode dependent
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        assert False, f"Unknown clip mode ({mode})."
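A hedged usage sketch (the model and loss below are stand-ins; the real call site in the training loop is outside this diff):

import torch
import torch.nn as nn

# Stand-in model and loss purely to produce some gradients.
model = nn.Linear(16, 4)
model(torch.randn(8, 16)).sum().backward()

# Pick one mode: 'norm' clips the global p-norm, 'value' clamps elements to [-value, value],
# 'agc' treats value as the AGC clip factor.
dispatch_clip_grad(model.parameters(), 1.0, mode='norm')
# dispatch_clip_grad(model.parameters(), 1.0, mode='value')
# dispatch_clip_grad(model.parameters(), 0.01, mode='agc')
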
@@ -1 +1 @@
-__version__ = '0.4.3'
+__version__ = '0.4.4'