import torch

from timm.utils.agc import adaptive_clip_grad


def dispatch_clip_grad(
        parameters,
        value: float,
        mode: str = 'norm',
        norm_type: float = 2.0,
):
    """ Dispatch to gradient clipping method

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode-dependent
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0 (not used by 'value' mode)

    Raises:
        AssertionError: if ``mode`` is not one of the supported modes.
    """
    if mode == 'norm':
        # value is forwarded as the max norm for torch's norm-based clipping
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        # element-wise clipping; norm_type is intentionally not forwarded here
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        # value is forwarded as the clip factor to adaptive_clip_grad
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would make an unknown mode silently
        # skip clipping. AssertionError (and the message) are kept so callers
        # that already catch it continue to work.
        raise AssertionError(f"Unknown clip mode ({mode}).")