Merge pull request #2 from rwightman/master

Merging commits from timm
pull/434/head
szingaro 4 years ago committed by GitHub
commit d79d372be4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
# These are supported funding model platforms
github: rwightman

@ -2,6 +2,27 @@
## What's New
### Feb 18, 2021
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
* Matching the original pre-processing as closely as possible I get these results:
* `dm_nfnet_f6` - 86.352
* `dm_nfnet_f5` - 86.100
* `dm_nfnet_f4` - 85.834
* `dm_nfnet_f3` - 85.676
* `dm_nfnet_f2` - 85.178
* `dm_nfnet_f1` - 84.696
* `dm_nfnet_f0` - 83.464
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
* PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
* PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
* AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
@ -238,6 +259,7 @@ Several (less common) features that I often utilize in my projects are included.
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
## Results

@ -21,7 +21,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
'*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
else:
EXCLUDE_FILTERS = NON_STD_FILTERS

@ -31,7 +31,7 @@ from .xception import *
from .xception_aligned import *
from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit

@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file. Default: False
"""
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
_logger.warning("Pretrained model URL does not exist, using random initialization.")
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
url = cfg['url']
@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
@ -381,3 +379,11 @@ def build_model_with_cfg(
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model
def model_parameters(model, exclude_head=False):
if exclude_head:
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
else:
return model.parameters()

@ -29,6 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

@ -2,8 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .padding import get_padding
from .conv2d_same import conv2d_same
from .padding import get_padding, get_padding_value, pad_same
def get_weight(module):
@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1,
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
groups=1, bias=False, eps=1e-5):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
groups=1, bias=False, eps=1e-5):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=0, dilation=dilation, groups=groups, bias=bias)
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.same_pad = is_dynamic
self.eps = eps
def get_weight(self):
@ -57,7 +59,9 @@ class StdConv2dSame(nn.Conv2d):
return weight
def forward(self, x):
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@ -66,19 +70,22 @@ class ScaledStdConv2d(nn.Conv2d):
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
def get_weight(self):
if self.use_layernorm:
@ -86,9 +93,51 @@ class ScaledStdConv2d(nn.Conv2d):
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight
return self.gain * weight
def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
class ScaledStdConv2dSame(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.scale = gamma * self.weight[0].numel() ** -0.5
self.same_pad = is_dynamic
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
# def get_weight(self):
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
# weight = (self.weight - mean) * scale
# return self.gain * weight
def get_weight(self):
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
return self.gain * weight
def forward(self, x):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

@ -24,12 +24,12 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
get_act_layer, get_act_fn, get_attn, make_divisible
def _dcfg(url='', **kwargs):
@ -38,75 +38,102 @@ def _dcfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
dm_nfnet_f0=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9),
dm_nfnet_f1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91),
dm_nfnet_f2=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92),
dm_nfnet_f3=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94),
dm_nfnet_f4=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951),
dm_nfnet_f5=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954),
dm_nfnet_f6=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956),
nfnet_f0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_f1=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
nfnet_f2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
nfnet_f3=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
nfnet_f4=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
nfnet_f5=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
nfnet_f6=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
nfnet_f7=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_f0s=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_f1s=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
nfnet_f2s=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
nfnet_f3s=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
nfnet_f4s=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
nfnet_f5s=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
nfnet_f6s=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
nfnet_f7s=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_l0a=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_l0b=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_l0c=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
nf_regnet_b1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
nf_resnet26=_dcfg(url=''),
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec
nf_regnet_b2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
nf_regnet_b3=_dcfg(
url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
nf_regnet_b4=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
nf_regnet_b5=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
nf_resnet26=_dcfg(url='', first_conv='stem.conv'),
nf_resnet50=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
nf_resnet101=_dcfg(url=''),
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
nf_resnet101=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet26=_dcfg(url=''),
nf_seresnet50=_dcfg(url=''),
nf_seresnet101=_dcfg(url=''),
nf_seresnet26=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet50=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet101=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet26=_dcfg(url=''),
nf_ecaresnet50=_dcfg(url=''),
nf_ecaresnet101=_dcfg(url=''),
nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'),
)
@ -115,7 +142,6 @@ class NfCfg:
depths: Tuple[int, int, int, int]
channels: Tuple[int, int, int, int]
alpha: float = 0.2
gamma_in_act: bool = False
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = None
@ -128,6 +154,8 @@ class NfCfg:
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
gamma_in_act: bool = False
same_padding: bool = False
skipinit: bool = False # disabled by default, non-trivial performance impact
zero_init_fc: bool = False
act_layer: str = 'silu'
@ -163,8 +191,26 @@ def _nfnet_cfg(
return cfg
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
return cfg
model_cfgs = dict(
# NFNet-F models w/ GeLU
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
# NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU)
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
@ -229,7 +275,7 @@ class GammaAct(nn.Module):
self.inplace = inplace
def forward(self, x):
return self.gamma * self.act_fn(x, inplace=self.inplace)
return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
def act_with_gamma(act_type, gamma: float = 1.):
@ -325,8 +371,7 @@ class NormFreeBlock(nn.Module):
out = self.drop_path(out)
if self.skipinit_gain is not None:
# this really slows things down for some reason, TBD
out = out * self.skipinit_gain
out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD
out = out * self.alpha + shortcut
return out
@ -419,12 +464,13 @@ class NormFreeNet(nn.Module):
self.num_classes = num_classes
self.drop_rate = drop_rate
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps
else:
act_layer = get_act_layer(cfg.act_layer)
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
@ -538,6 +584,69 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
**kwargs)
@register_model
def dm_nfnet_f0(pretrained=False, **kwargs):
""" NFNet-F0 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f1(pretrained=False, **kwargs):
""" NFNet-F1 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f2(pretrained=False, **kwargs):
""" NFNet-F2 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f3(pretrained=False, **kwargs):
""" NFNet-F3 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f4(pretrained=False, **kwargs):
""" NFNet-F4 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f5(pretrained=False, **kwargs):
""" NFNet-F5 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f6(pretrained=False, **kwargs):
""" NFNet-F6 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f0(pretrained=False, **kwargs):
""" NFNet-F0

@ -1,4 +1,6 @@
from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy

@ -0,0 +1,42 @@
""" Adaptive Gradient Clipping
An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:},
year={2021}
}
Code references:
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
def unitwise_norm(x, norm_type=2.0):
if x.ndim <= 1:
return x.norm(norm_type)
else:
# works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor
# might need special cases for other weights (possibly MHA) where this may not be true
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
p_data = p.detach()
g_data = p.grad.detach()
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
p.grad.detach().copy_(new_grads)

@ -0,0 +1,23 @@
import torch
from timm.utils.agc import adaptive_clip_grad
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
""" Dispatch to gradient clipping method
Args:
parameters (Iterable): model parameters to clip
value (float): clipping value/factor/norm, mode dependant
mode (str): clipping mode, one of 'norm', 'value', 'agc'
norm_type (float): p-norm, default 2.0
"""
if mode == 'norm':
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
elif mode == 'value':
torch.nn.utils.clip_grad_value_(parameters, value)
elif mode == 'agc':
adaptive_clip_grad(parameters, value, norm_type=norm_type)
else:
assert False, f"Unknown clip mode ({mode})."

@ -11,15 +11,17 @@ except ImportError:
amp = None
has_apex = False
from .clip_grad import dispatch_clip_grad
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if clip_grad is not None:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
optimizer.step()
def state_dict(self):
@ -37,12 +39,12 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
self._scaler.scale(loss).backward(create_graph=create_graph)
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()

@ -1 +1 @@
__version__ = '0.4.3'
__version__ = '0.4.4'

@ -29,7 +29,7 @@ import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
@ -116,7 +116,8 @@ parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
# Learning rate schedule parameters
@ -637,11 +638,16 @@ def train_one_epoch(
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
if model_ema is not None:

Loading…
Cancel
Save