Add NFNet-F model weights ported from DeepMind Haiku impl and new set of models w/ compatible config.

pull/447/head
Ross Wightman 3 years ago
parent 4ea5931964
commit 678ba4e0a2

@@ -2,6 +2,20 @@
## What's New
### Feb 18, 2021
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn (see the usage sketch below the results).
* These models are big, expect to run out of GPU memory. With the GELU activation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
* Original model results are based on pre-processing that differs from all other models, so you'll see different results in the results csv (once updated).
* Matching the original pre-processing as closely as possible, I get these results:
* `dm_nfnet_f6` - 86.352
* `dm_nfnet_f5` - 86.100
* `dm_nfnet_f4` - 85.834
* `dm_nfnet_f3` - 85.676
* `dm_nfnet_f2` - 85.178
* `dm_nfnet_f1` - 84.696
* `dm_nfnet_f0` - 83.464
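A minimal usage sketch for the new `dm_` variants (illustrative, assuming the weights linked above are published for `timm.create_model`):

```python
import timm
import torch

# pretrained weights are the ported Haiku ones linked above
model = timm.create_model('dm_nfnet_f0', pretrained=True).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 192, 192))  # train size; test_input_size is (3, 256, 256)
print(out.shape)  # torch.Size([1, 1000])
```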
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via a mode arg that defaults to the previous 'norm' mode. For backwards arg compat, `--clip-grad` must be specified to enable clipping when using train.py (a sketch of the AGC rule follows below).
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
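Per unit (output channel / row), AGC rescales the gradient whenever its norm exceeds `clip_factor` times the parameter norm. Roughly (an illustrative sketch of the paper's rule, not the exact implementation in timm's utils that `--clip-mode agc` uses):

```python
import torch

def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    # whole-tensor norm for biases/gains, per output-unit (first dim) otherwise
    if x.ndim <= 1:
        return x.norm()
    return x.norm(dim=tuple(range(1, x.ndim)), keepdim=True)

def adaptive_clip_grad_(parameters, clip_factor=0.01, eps=1e-3):
    # rescale each unit's gradient so ||g|| <= clip_factor * max(||w||, eps)
    for p in parameters:
        if p.grad is None:
            continue
        w_norm = unitwise_norm(p.detach()).clamp_(min=eps)
        g_norm = unitwise_norm(p.grad.detach())
        scale = (clip_factor * w_norm / g_norm.clamp(min=1e-6)).clamp(max=1.0)
        p.grad.detach().mul_(scale)
```

Called between `loss.backward()` and `optimizer.step()`, same as PyTorch's built-in gradient clipping.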

@@ -29,6 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

@@ -2,8 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .padding import get_padding
from .conv2d_same import conv2d_same
from .padding import get_padding, get_padding_value, pad_same
def get_weight(module):
@@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1,
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
groups=1, bias=False, eps=1e-5):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
@@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
groups=1, bias=False, eps=1e-5):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=0, dilation=dilation, groups=groups, bias=bias)
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.same_pad = is_dynamic
self.eps = eps
def get_weight(self):
@@ -57,7 +59,9 @@ class StdConv2dSame(nn.Conv2d):
return weight
def forward(self, x):
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@@ -68,17 +72,18 @@ class ScaledStdConv2d(nn.Conv2d):
https://arxiv.org/abs/2101.08692
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
self.use_layernorm = use_layernorm # experimental, hijacking the LN kernel is slightly faster & uses less GPU memory
def get_weight(self):
if self.use_layernorm:
@@ -86,9 +91,52 @@ class ScaledStdConv2d(nn.Conv2d):
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight
return self.gain * weight
def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
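# Worked example (illustrative, not part of this diff): for a 3x3 conv w/ 64 in_channels,
# fan-in = 64 * 3 * 3 = 576, so scale = gamma / sqrt(576) = gamma / 24 and the effective
# weight is gain * gamma * (W - mean) / (sqrt(576) * (std + eps)), w/ mean/std computed
# per output filter (dim=[1, 2, 3]).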
class ScaledStdConv2dSame(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
NOTE: operations and default eps slightly changed from non-SAME impl to closer match Deepmind Haiku impl.
Fore the sake of completeness, numeric differences are minor with arprox .005 top-1 difference.
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.scale = gamma * self.weight[0].numel() ** -0.5
self.same_pad = is_dynamic
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, hijacking the LN kernel is slightly faster & uses less GPU memory
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
# def get_weight(self):
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
# weight = (self.weight - mean) * scale
# return weight  # gain is already folded into scale above
def get_weight(self):
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
return self.gain * weight
def forward(self, x):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
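# Illustrative sanity check of the dynamic SAME-padding path (commented out, mirroring the
# note above): with stride > 1 a static symmetric pad can't reproduce SAME, so padding
# resolves to dynamic and pad_same keeps output spatial dims at ceil(in / stride).
# conv = ScaledStdConv2dSame(32, 64, kernel_size=3, stride=2)
# y = conv(torch.randn(1, 32, 15, 15))
# assert y.shape == (1, 64, 8, 8)  # ceil(15 / 2) == 8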

@@ -24,12 +24,12 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
get_act_layer, get_act_fn, get_attn, make_divisible
def _dcfg(url='', **kwargs):
@@ -38,75 +38,102 @@ def _dcfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
dm_nfnet_f0=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9),
dm_nfnet_f1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91),
dm_nfnet_f2=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92),
dm_nfnet_f3=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94),
dm_nfnet_f4=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951),
dm_nfnet_f5=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954),
dm_nfnet_f6=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956),
nfnet_f0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_f1=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
nfnet_f2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
nfnet_f3=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
nfnet_f4=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
nfnet_f5=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
nfnet_f6=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
nfnet_f7=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_f0s=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_f1s=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
nfnet_f2s=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
nfnet_f3s=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
nfnet_f4s=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
nfnet_f5s=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
nfnet_f6s=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
nfnet_f7s=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_l0a=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_l0b=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_l0c=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
nf_regnet_b1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
nf_resnet26=_dcfg(url=''),
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec
nf_regnet_b2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
nf_regnet_b3=_dcfg(
url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
nf_regnet_b4=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
nf_regnet_b5=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
nf_resnet26=_dcfg(url='', first_conv='stem.conv'),
nf_resnet50=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
nf_resnet101=_dcfg(url=''),
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
nf_resnet101=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet26=_dcfg(url=''),
nf_seresnet50=_dcfg(url=''),
nf_seresnet101=_dcfg(url=''),
nf_seresnet26=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet50=_dcfg(url='', first_conv='stem.conv'),
nf_seresnet101=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet26=_dcfg(url=''),
nf_ecaresnet50=_dcfg(url=''),
nf_ecaresnet101=_dcfg(url=''),
nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'),
nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'),
)
@@ -115,7 +142,6 @@ class NfCfg:
depths: Tuple[int, int, int, int]
channels: Tuple[int, int, int, int]
alpha: float = 0.2
gamma_in_act: bool = False
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = None
@@ -128,6 +154,8 @@ class NfCfg:
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
gamma_in_act: bool = False
same_padding: bool = False
skipinit: bool = False # disabled by default, non-trivial performance impact
zero_init_fc: bool = False
act_layer: str = 'silu'
@@ -163,8 +191,26 @@ def _nfnet_cfg(
return cfg
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
return cfg
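# e.g. (illustrative trace of the above) _dm_nfnet_cfg(depths=(1, 2, 6, 3)) yields
# stem_chs=128, channels=(256, 512, 1536, 1536), num_features=int(1536 * 2.0)=3072,
# SE attention w/ reduction_ratio=0.5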
model_cfgs = dict(
# NFNet-F models w/ GeLU
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
# NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU)
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
@@ -229,7 +275,7 @@ class GammaAct(nn.Module):
self.inplace = inplace
def forward(self, x):
return self.gamma * self.act_fn(x, inplace=self.inplace)
return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
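# NOTE (illustrative, not part of this diff): the per-activation gamma constants in
# _nonlin_gamma are chosen so the activation is approx variance-preserving for
# x ~ N(0, 1), i.e. gamma ~= 1 / std(act(x)); e.g. torch.randn(1 << 24).relu().std()
# ~= 0.5838, giving gamma ~= 1.71 for ReLU.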
def act_with_gamma(act_type, gamma: float = 1.):
@@ -325,8 +371,7 @@ class NormFreeBlock(nn.Module):
out = self.drop_path(out)
if self.skipinit_gain is not None:
# this really slows things down for some reason, TBD
out = out * self.skipinit_gain
out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD
out = out * self.alpha + shortcut
return out
@@ -419,12 +464,13 @@ class NormFreeNet(nn.Module):
self.num_classes = num_classes
self.drop_rate = drop_rate
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps
else:
act_layer = get_act_layer(cfg.act_layer)
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
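# NOTE (illustrative): folding gamma into the act fn (DM Haiku style) vs into the conv
# weight scale should be equivalent in effect, since scaling a conv's weight by gamma
# scales its output the same way as scaling its input; the eps handling above is the
# main difference between the two paths.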
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
@@ -538,6 +584,69 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
**kwargs)
@register_model
def dm_nfnet_f0(pretrained=False, **kwargs):
""" NFNet-F0 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f1(pretrained=False, **kwargs):
""" NFNet-F1 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f2(pretrained=False, **kwargs):
""" NFNet-F2 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f3(pretrained=False, **kwargs):
""" NFNet-F3 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f4(pretrained=False, **kwargs):
""" NFNet-F4 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f5(pretrained=False, **kwargs):
""" NFNet-F5 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
@register_model
def dm_nfnet_f6(pretrained=False, **kwargs):
""" NFNet-F6 (DeepMind weight compatible)
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f0(pretrained=False, **kwargs):
""" NFNet-F0
