From 4f49b94311860e5b695d2a919413d1aae4e0eb9c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Feb 2021 23:22:44 -0800 Subject: [PATCH 1/8] Initial AGC impl. Still testing. --- timm/models/__init__.py | 2 +- timm/models/helpers.py | 20 +++++++++++++------- timm/utils/__init__.py | 2 ++ timm/utils/agc.py | 42 +++++++++++++++++++++++++++++++++++++++++ timm/utils/clip_grad.py | 23 ++++++++++++++++++++++ timm/utils/cuda.py | 10 ++++++---- train.py | 11 ++++++++--- 7 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 timm/utils/agc.py create mode 100644 timm/utils/clip_grad.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index dc56848e..8d99d19b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -31,7 +31,7 @@ from .xception import * from .xception_aligned import * from .factory import create_model -from .helpers import load_checkpoint, resume_checkpoint +from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d9b501da..4d9b8a28 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: - _logger.warning("Pretrained model URL does not exist, using random initialization.") + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): + _logger.warning("No pretrained weights exist for this model. Using random initialization.") return url = cfg['url'] @@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): _logger.warning("No pretrained weights exist for this model. Using random initialization.") return @@ -376,3 +374,11 @@ def build_model_with_cfg( model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 0f7c4b05..1c526e8c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -1,4 +1,6 @@ +from .agc import adaptive_clip_grad from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .distributed import distribute_bn, reduce_tensor from .jit import set_jit_legacy diff --git a/timm/utils/agc.py b/timm/utils/agc.py new file mode 100644 index 00000000..f5140172 --- /dev/null +++ b/timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py new file mode 100644 index 00000000..7eb40697 --- /dev/null +++ b/timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index bcd29f58..9e7bddf3 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -11,15 +11,17 @@ except ImportError: amp = None has_apex = False +from .clip_grad import dispatch_clip_grad + class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) optimizer.step() def state_dict(self): @@ -37,12 +39,12 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place - torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) self._scaler.step(optimizer) self._scaler.update() diff --git a/train.py b/train.py index 0333d72f..b787a88c 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -637,11 +637,16 @@ def train_one_epoch( optimizer.zero_grad() if loss_scaler is not None: loss_scaler( - loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) + loss, optimizer, + clip_grad=args.clip_grad, clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, mode=args.clip_mode) optimizer.step() if model_ema is not None: From 01653db104c8d60d2bac643b169c04139f3ae668 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Feb 2021 23:27:16 -0800 Subject: [PATCH 2/8] Missed clip-mode arg for repo train script --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index b787a88c..9abcfed3 100755 --- a/train.py +++ b/train.py @@ -116,7 +116,8 @@ parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') - +parser.add_argument('--clip-mode', type=str, default='norm', + help='Gradient clipping mode. One of ("norm", "value", "agc")') # Learning rate schedule parameters From 9de2ec5e442d2b714b004a7f2c2baf272c9e3854 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Feb 2021 09:12:23 -0800 Subject: [PATCH 3/8] Update README for AGC and bump version to 0.4.4 --- README.md | 8 ++++++++ timm/version.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c4f3a588..8b1f6f28 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,13 @@ ## What's New +### Feb 16, 2021 +* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. + * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` + * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` + * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training. + ### Feb 12, 2021 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs @@ -238,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included. * Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151) * Blur Pooling (https://arxiv.org/abs/1904.11486) * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper? +* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets) ## Results diff --git a/timm/version.py b/timm/version.py index 908c0bb7..9a8e054a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.3' +__version__ = '0.4.4' From 361fd0fc40708c868ff92218f24b01113a617572 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Feb 2021 10:27:41 -0800 Subject: [PATCH 4/8] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8b1f6f28..421bced4 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` - * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training. + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. ### Feb 12, 2021 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs From 678ba4e0a2c0b52c5e7b2ec0ba689399840282ee Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Feb 2021 12:28:46 -0800 Subject: [PATCH 5/8] Add NFNet-F model weights ported from DeepMind Haiku impl and new set of models w/ compatible config. --- README.md | 14 +++ timm/models/layers/__init__.py | 2 +- timm/models/layers/std_conv.py | 82 +++++++++++--- timm/models/nfnet.py | 199 +++++++++++++++++++++++++-------- 4 files changed, 234 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 421bced4..012f262e 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,20 @@ ## What's New +### Feb 18, 2021 +* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). + * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. + * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. + * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). + * Matching the original pre-processing as closely as possible I get these results: + * `dm_nfnet_f6` - 86.352 + * `dm_nfnet_f5` - 86.100 + * `dm_nfnet_f4` - 85.834 + * `dm_nfnet_f3` - 85.676 + * `dm_nfnet_f2` - 85.178 + * `dm_nfnet_f1` - 84.696 + * `dm_nfnet_f0` - 83.464 + ### Feb 16, 2021 * Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 6eb9f8a1..f8d8d8c0 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -29,6 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule from .split_attn import SplitAttnConv2d from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model -from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .weight_init import trunc_normal_ diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index 80a8e5d7..cddfa258 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -2,8 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .padding import get_padding -from .conv2d_same import conv2d_same +from .padding import get_padding, get_padding_value, pad_same def get_weight(module): @@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, - padding=None, dilation=1, groups=1, bias=False, eps=1e-5): + self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1, + groups=1, bias=False, eps=1e-5): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( @@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5): + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, + groups=1, bias=False, eps=1e-5): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( - in_channel, out_channels, kernel_size, stride=stride, - padding=0, dilation=dilation, groups=groups, bias=bias) + in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.same_pad = is_dynamic self.eps = eps def get_weight(self): @@ -57,7 +59,9 @@ class StdConv2dSame(nn.Conv2d): return weight def forward(self, x): - x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -68,17 +72,18 @@ class ScaledStdConv2d(nn.Conv2d): https://arxiv.org/abs/2101.08692 """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, - bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False): + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, + bias=True, gamma=1.0, eps=1e-5, use_layernorm=False): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( - in_channels, out_channels, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) - self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) self.eps = eps ** 2 if use_layernorm else eps - self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use + self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel def get_weight(self): if self.use_layernorm: @@ -86,9 +91,52 @@ class ScaledStdConv2d(nn.Conv2d): else: std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) weight = self.scale * (self.weight - mean) / (std + self.eps) - if self.gain is not None: - weight = weight * self.gain - return weight + return self.gain * weight + + def forward(self, x): + return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ScaledStdConv2dSame(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support + + NOTE: operations and default eps slightly changed from non-SAME impl to closer match Deepmind Haiku impl. + Fore the sake of completeness, numeric differences are minor with arprox .005 top-1 difference. + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1, + bias=True, gamma=1.0, eps=1e-5, use_layernorm=False): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps ** 2 if use_layernorm else eps + self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel + + # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem + # to make much numerical difference (+/- .002 to .004) in top-1 during eval. + # def get_weight(self): + # var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + # scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain + # weight = (self.weight - mean) * scale + # return self.gain * weight + + def get_weight(self): + if self.use_layernorm: + weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + else: + std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = self.scale * (self.weight - mean) / (std + self.eps) + return self.gain * weight def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index b43ee5ef..dafe2efa 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -24,12 +24,12 @@ from functools import partial import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .registry import register_model -from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn +from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ + get_act_layer, get_act_fn, get_attn, make_divisible def _dcfg(url='', **kwargs): @@ -38,75 +38,102 @@ def _dcfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', **kwargs } default_cfgs = dict( + dm_nfnet_f0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth', + pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9), + dm_nfnet_f1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91), + dm_nfnet_f2=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92), + dm_nfnet_f3=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94), + dm_nfnet_f4=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth', + pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951), + dm_nfnet_f5=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth', + pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954), + dm_nfnet_f6=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth', + pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956), + nfnet_f0=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), nfnet_f1=_dcfg( - url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), nfnet_f2=_dcfg( - url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), nfnet_f3=_dcfg( - url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), nfnet_f4=_dcfg( - url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), nfnet_f5=_dcfg( - url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), nfnet_f6=_dcfg( - url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), nfnet_f7=_dcfg( - url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), nfnet_f0s=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), nfnet_f1s=_dcfg( - url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), nfnet_f2s=_dcfg( - url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), nfnet_f3s=_dcfg( - url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), nfnet_f4s=_dcfg( - url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), nfnet_f5s=_dcfg( - url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), nfnet_f6s=_dcfg( - url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), nfnet_f7s=_dcfg( - url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), nfnet_l0a=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), nfnet_l0b=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), nfnet_l0c=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), - nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nf_regnet_b0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), nf_regnet_b1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', - pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec - nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)), - nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)), - nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)), - nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)), - - nf_resnet26=_dcfg(url=''), + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec + nf_regnet_b2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'), + nf_regnet_b3=_dcfg( + url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'), + nf_regnet_b4=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'), + nf_regnet_b5=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'), + + nf_resnet26=_dcfg(url='', first_conv='stem.conv'), nf_resnet50=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', - pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94), - nf_resnet101=_dcfg(url=''), + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'), + nf_resnet101=_dcfg(url='', first_conv='stem.conv'), - nf_seresnet26=_dcfg(url=''), - nf_seresnet50=_dcfg(url=''), - nf_seresnet101=_dcfg(url=''), + nf_seresnet26=_dcfg(url='', first_conv='stem.conv'), + nf_seresnet50=_dcfg(url='', first_conv='stem.conv'), + nf_seresnet101=_dcfg(url='', first_conv='stem.conv'), - nf_ecaresnet26=_dcfg(url=''), - nf_ecaresnet50=_dcfg(url=''), - nf_ecaresnet101=_dcfg(url=''), + nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'), + nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'), + nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'), ) @@ -115,7 +142,6 @@ class NfCfg: depths: Tuple[int, int, int, int] channels: Tuple[int, int, int, int] alpha: float = 0.2 - gamma_in_act: bool = False stem_type: str = '3x3' stem_chs: Optional[int] = None group_size: Optional[int] = None @@ -128,6 +154,8 @@ class NfCfg: ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + gamma_in_act: bool = False + same_padding: bool = False skipinit: bool = False # disabled by default, non-trivial performance impact zero_init_fc: bool = False act_layer: str = 'silu' @@ -163,8 +191,26 @@ def _nfnet_cfg( return cfg +def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): + attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, + bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, + num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs) + return cfg + + model_cfgs = dict( - # NFNet-F models w/ GeLU + # NFNet-F models w/ GELU compatible with DeepMind weights + dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), + dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)), + dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)), + dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)), + dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)), + dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)), + dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)), + + # NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU) nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), @@ -229,7 +275,7 @@ class GammaAct(nn.Module): self.inplace = inplace def forward(self, x): - return self.gamma * self.act_fn(x, inplace=self.inplace) + return self.act_fn(x, inplace=self.inplace).mul_(self.gamma) def act_with_gamma(act_type, gamma: float = 1.): @@ -325,8 +371,7 @@ class NormFreeBlock(nn.Module): out = self.drop_path(out) if self.skipinit_gain is not None: - # this really slows things down for some reason, TBD - out = out * self.skipinit_gain + out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD out = out * self.alpha + shortcut return out @@ -419,12 +464,13 @@ class NormFreeNet(nn.Module): self.num_classes = num_classes self.drop_rate = drop_rate assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." + conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer]) - conv_layer = partial(ScaledStdConv2d, bias=True, gain=True) + conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps else: act_layer = get_act_layer(cfg.act_layer) - conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer]) + conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer]) attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) @@ -538,6 +584,69 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): **kwargs) +@register_model +def dm_nfnet_f0(pretrained=False, **kwargs): + """ NFNet-F0 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f1(pretrained=False, **kwargs): + """ NFNet-F1 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f2(pretrained=False, **kwargs): + """ NFNet-F2 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f3(pretrained=False, **kwargs): + """ NFNet-F3 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f4(pretrained=False, **kwargs): + """ NFNet-F4 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f5(pretrained=False, **kwargs): + """ NFNet-F5 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f6(pretrained=False, **kwargs): + """ NFNet-F6 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs) + + @register_model def nfnet_f0(pretrained=False, **kwargs): """ NFNet-F0 From 8563609b2828a27608dab0117027d5c2584529b7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Feb 2021 12:44:08 -0800 Subject: [PATCH 6/8] Update notes in ScaledStdConv impl --- timm/models/layers/std_conv.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index cddfa258..077dc5fb 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -70,6 +70,8 @@ class ScaledStdConv2d(nn.Conv2d): Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. """ def __init__( @@ -100,11 +102,10 @@ class ScaledStdConv2d(nn.Conv2d): class ScaledStdConv2dSame(nn.Conv2d): """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support - NOTE: operations and default eps slightly changed from non-SAME impl to closer match Deepmind Haiku impl. - Fore the sake of completeness, numeric differences are minor with arprox .005 top-1 difference. - Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. """ def __init__( From da4839530c3decf1fa419a2e94f35c20bbeeadf4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Feb 2021 13:43:04 -0800 Subject: [PATCH 7/8] Fix test model filter to include dm_ variants that break GitHub CI limits --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 407e0fe5..d085a623 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -21,7 +21,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', - 'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS + '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS From 48371a33b11fc30ca23ed8988619af08902b215b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Feb 2021 16:12:53 -0800 Subject: [PATCH 8/8] Create FUNDING.yml $cale up the training... --- .github/FUNDING.yml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..ab0474f2 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,2 @@ +# These are supported funding model platforms +github: rwightman