From 4f49b94311860e5b695d2a919413d1aae4e0eb9c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Feb 2021 23:22:44 -0800 Subject: [PATCH] Initial AGC impl. Still testing. --- timm/models/__init__.py | 2 +- timm/models/helpers.py | 20 +++++++++++++------- timm/utils/__init__.py | 2 ++ timm/utils/agc.py | 42 +++++++++++++++++++++++++++++++++++++++++ timm/utils/clip_grad.py | 23 ++++++++++++++++++++++ timm/utils/cuda.py | 10 ++++++---- train.py | 11 ++++++++--- 7 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 timm/utils/agc.py create mode 100644 timm/utils/clip_grad.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index dc56848e..8d99d19b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -31,7 +31,7 @@ from .xception import * from .xception_aligned import * from .factory import create_model -from .helpers import load_checkpoint, resume_checkpoint +from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d9b501da..4d9b8a28 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: - _logger.warning("Pretrained model URL does not exist, using random initialization.") + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): + _logger.warning("No pretrained weights exist for this model. Using random initialization.") return url = cfg['url'] @@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): _logger.warning("No pretrained weights exist for this model. Using random initialization.") return @@ -376,3 +374,11 @@ def build_model_with_cfg( model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 0f7c4b05..1c526e8c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -1,4 +1,6 @@ +from .agc import adaptive_clip_grad from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .distributed import distribute_bn, reduce_tensor from .jit import set_jit_legacy diff --git a/timm/utils/agc.py b/timm/utils/agc.py new file mode 100644 index 00000000..f5140172 --- /dev/null +++ b/timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py new file mode 100644 index 00000000..7eb40697 --- /dev/null +++ b/timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index bcd29f58..9e7bddf3 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -11,15 +11,17 @@ except ImportError: amp = None has_apex = False +from .clip_grad import dispatch_clip_grad + class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) optimizer.step() def state_dict(self): @@ -37,12 +39,12 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place - torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) self._scaler.step(optimizer) self._scaler.update() diff --git a/train.py b/train.py index 0333d72f..b787a88c 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -637,11 +637,16 @@ def train_one_epoch( optimizer.zero_grad() if loss_scaler is not None: loss_scaler( - loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) + loss, optimizer, + clip_grad=args.clip_grad, clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, mode=args.clip_mode) optimizer.step() if model_ema is not None: