diff --git a/README.md b/README.md
index 421bced4..012f262e 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,20 @@
 
 ## What's New
 
+### Feb 18, 2021
+* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
+  * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
+  * These models are big; expect to run out of GPU memory. With the GELU activation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
+  * Original model results are based on pre-processing that is not the same as all other models, so you'll see different results in the results csv (once updated).
+  * Matching the original pre-processing as closely as possible, I get these results:
+    * `dm_nfnet_f6` - 86.352
+    * `dm_nfnet_f5` - 86.100
+    * `dm_nfnet_f4` - 85.834
+    * `dm_nfnet_f3` - 85.676
+    * `dm_nfnet_f2` - 85.178
+    * `dm_nfnet_f1` - 84.696
+    * `dm_nfnet_f0` - 83.464
+
 ### Feb 16, 2021
 * Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
   * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
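As orientation for the README notes above, a minimal usage sketch (not part of the patch; it assumes only the standard `timm.create_model` API and the `dm_nfnet_f0` default cfg registered further down in this diff):

```python
import torch
import timm

# Create a DeepMind-weight-compatible NFNet-F0; pretrained weights download from
# the release URL registered in default_cfgs below.
model = timm.create_model('dm_nfnet_f0', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 192, 192))  # 192x192 train size; test_input_size is 256
print(logits.shape)  # torch.Size([1, 1000])
```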
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 6eb9f8a1..f8d8d8c0 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -29,6 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
 from .space_to_depth import SpaceToDepthModule
 from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
-from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .weight_init import trunc_normal_
diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py
index 80a8e5d7..cddfa258 100644
--- a/timm/models/layers/std_conv.py
+++ b/timm/models/layers/std_conv.py
@@ -2,8 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .padding import get_padding
-from .conv2d_same import conv2d_same
+from .padding import get_padding, get_padding_value, pad_same
 
 
 def get_weight(module):
@@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d):
     https://arxiv.org/abs/1903.10520v2
     """
     def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1,
-            padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
+            self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
+            groups=1, bias=False, eps=1e-5):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
@@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d):
     https://arxiv.org/abs/1903.10520v2
     """
     def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
+            self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
+            groups=1, bias=False, eps=1e-5):
+        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
-            in_channel, out_channels, kernel_size, stride=stride,
-            padding=0, dilation=dilation, groups=groups, bias=bias)
+            in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+            groups=groups, bias=bias)
+        self.same_pad = is_dynamic
         self.eps = eps
 
     def get_weight(self):
@@ -57,7 +59,9 @@ class StdConv2dSame(nn.Conv2d):
         return weight
 
     def forward(self, x):
-        x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        if self.same_pad:
+            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x
 
 
@@ -68,17 +72,18 @@ class ScaledStdConv2d(nn.Conv2d):
 
     https://arxiv.org/abs/2101.08692
     """
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
-                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
+    def __init__(
+            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
+            bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
-            in_channels, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
-        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
+            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+            groups=groups, bias=bias)
+        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
         self.eps = eps ** 2 if use_layernorm else eps
-        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use
+        self.use_layernorm = use_layernorm  # experimental, slightly faster / less GPU memory by hijacking the LN kernel
 
     def get_weight(self):
@@ -86,9 +91,52 @@ class ScaledStdConv2d(nn.Conv2d):
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
             weight = self.scale * (self.weight - mean) / (std + self.eps)
-        if self.gain is not None:
-            weight = weight * self.gain
-        return weight
+        return self.gain * weight
+
+    def forward(self, x):
+        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class ScaledStdConv2dSame(nn.Conv2d):
+    """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
+
+    NOTE: operations and default eps slightly changed from the non-SAME impl to more closely match
+    the DeepMind Haiku impl. The numeric differences are minor, approx .005 top-1 difference.
+
+    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+        https://arxiv.org/abs/2101.08692
+    """
+
+    def __init__(
+            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
+            bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
+        padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+            groups=groups, bias=bias)
+        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
+        self.scale = gamma * self.weight[0].numel() ** -0.5
+        self.same_pad = is_dynamic
+        self.eps = eps ** 2 if use_layernorm else eps
+        self.use_layernorm = use_layernorm  # experimental, slightly faster / less GPU memory by hijacking the LN kernel
+
+    # NOTE an alternate formulation to consider, closer to the DeepMind Haiku impl but doesn't seem
+    # to make much numerical difference (+/- .002 to .004 top-1 during eval). The gain is already
+    # folded into the scale here, so it must not be applied a second time on return.
+    # def get_weight(self):
+    #     var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+    #     scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
+    #     weight = (self.weight - mean) * scale
+    #     return weight
+
+    def get_weight(self):
+        if self.use_layernorm:
+            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = self.scale * (self.weight - mean) / (std + self.eps)
+        return self.gain * weight
 
     def forward(self, x):
+        if self.same_pad:
+            x = pad_same(x, self.kernel_size, self.stride, self.dilation)
         return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
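The `get_weight` logic above is compact; here is a standalone sketch of the same Scaled Weight Standardization math (shapes and the gamma constant are illustrative assumptions, not values from the patch):

```python
import torch
import torch.nn.functional as F

# Standardize each output filter over its fan-in, then apply gamma / sqrt(fan-in)
# and a learned per-filter gain, mirroring ScaledStdConv2d.get_weight() above.
weight = torch.randn(64, 32, 3, 3)             # (out_ch, in_ch, kH, kW), illustrative
gain = torch.ones(64, 1, 1, 1)                 # learned gain, initialized to 1
gamma, eps = 1.7015, 1e-5                      # e.g. a GELU gamma, as in _nonlin_gamma
scale = gamma * weight[0].numel() ** -0.5      # gamma * 1 / sqrt(fan-in)
std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = gain * (scale * (weight - mean) / (std + eps))

x = torch.randn(8, 32, 16, 16)
out = F.conv2d(x, w, padding=1)                # the weight is recomputed every forward
print(out.shape)                               # torch.Size([8, 64, 16, 16])
```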
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index b43ee5ef..dafe2efa 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -24,12 +24,12 @@ from functools import partial
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
+    get_act_layer, get_act_fn, get_attn, make_divisible
 
 
 def _dcfg(url='', **kwargs):
@@ -38,75 +38,102 @@ def _dcfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'first_conv': 'stem.conv1', 'classifier': 'head.fc',
         **kwargs
     }
 
 
 default_cfgs = dict(
+    dm_nfnet_f0=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
+        pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9),
+    dm_nfnet_f1=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
+        pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91),
+    dm_nfnet_f2=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92),
+    dm_nfnet_f3=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
+        pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94),
+    dm_nfnet_f4=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
+        pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951),
+    dm_nfnet_f5=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
+        pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954),
+    dm_nfnet_f6=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
+        pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956),
+
     nfnet_f0=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
     nfnet_f1=_dcfg(
-        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
     nfnet_f2=_dcfg(
-        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
+        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
     nfnet_f3=_dcfg(
-        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
+        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
     nfnet_f4=_dcfg(
-        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
+        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
     nfnet_f5=_dcfg(
-        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
+        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
     nfnet_f6=_dcfg(
-        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
+        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
     nfnet_f7=_dcfg(
-        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
+        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
 
     nfnet_f0s=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
     nfnet_f1s=_dcfg(
-        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
     nfnet_f2s=_dcfg(
-        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
+        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
     nfnet_f3s=_dcfg(
-        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
+        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
     nfnet_f4s=_dcfg(
-        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
+        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
     nfnet_f5s=_dcfg(
-        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
+        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
     nfnet_f6s=_dcfg(
-        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
+        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
     nfnet_f7s=_dcfg(
-        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
+        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
 
     nfnet_l0a=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
     nfnet_l0b=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
     nfnet_l0c=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
 
-    nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+    nf_regnet_b0=_dcfg(
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
     nf_regnet_b1=_dcfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
-        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)),  # NOT to paper spec
-    nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
-    nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
-    nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
-    nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
-
-    nf_resnet26=_dcfg(url=''),
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'),  # NOT to paper spec
+    nf_regnet_b2=_dcfg(
+        url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
+    nf_regnet_b3=_dcfg(
+        url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
+    nf_regnet_b4=_dcfg(
+        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
+    nf_regnet_b5=_dcfg(
+        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
+
+    nf_resnet26=_dcfg(url='', first_conv='stem.conv'),
     nf_resnet50=_dcfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
-        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
-    nf_resnet101=_dcfg(url=''),
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
+    nf_resnet101=_dcfg(url='', first_conv='stem.conv'),
 
-    nf_seresnet26=_dcfg(url=''),
-    nf_seresnet50=_dcfg(url=''),
-    nf_seresnet101=_dcfg(url=''),
+    nf_seresnet26=_dcfg(url='', first_conv='stem.conv'),
+    nf_seresnet50=_dcfg(url='', first_conv='stem.conv'),
+    nf_seresnet101=_dcfg(url='', first_conv='stem.conv'),
 
-    nf_ecaresnet26=_dcfg(url=''),
-    nf_ecaresnet50=_dcfg(url=''),
-    nf_ecaresnet101=_dcfg(url=''),
+    nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'),
+    nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'),
+    nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'),
 )
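Since each entry above carries its own `input_size` / `test_input_size` / `crop_pct`, the usual way to consume these cfgs is via timm's data config helpers. A sketch, assuming the standard `resolve_data_config` / `create_transform` helpers rather than anything added in this patch:

```python
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

model = timm.create_model('dm_nfnet_f3', pretrained=False)
config = resolve_data_config({}, model=model)    # pulls input_size, crop_pct, interpolation, mean/std
transform = create_transform(**config)           # eval transform matching the cfg
print(config['input_size'], config['crop_pct'])  # (3, 320, 320) 0.94
```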
@@ -115,7 +142,6 @@ class NfCfg:
     depths: Tuple[int, int, int, int]
     channels: Tuple[int, int, int, int]
     alpha: float = 0.2
-    gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
     group_size: Optional[int] = None
@@ -128,6 +154,8 @@ class NfCfg:
     ch_div: int = 8  # round channels % 8 == 0 to keep tensor-core use optimal
     reg: bool = False  # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
     extra_conv: bool = False  # extra 3x3 bottleneck convolution for NFNet models
+    gamma_in_act: bool = False
+    same_padding: bool = False
     skipinit: bool = False  # disabled by default, non-trivial performance impact
     zero_init_fc: bool = False
     act_layer: str = 'silu'
@@ -163,8 +191,26 @@ def _nfnet_cfg(
     return cfg
 
 
+def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
+    attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
+    cfg = NfCfg(
+        depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
+        bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
+        num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
+    return cfg
+
+
 model_cfgs = dict(
-    # NFNet-F models w/ GeLU
+    # NFNet-F models w/ GELU, compatible with DeepMind weights
+    dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
+    dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
+    dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
+    dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
+    dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
+    dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
+    dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
+
+    # NFNet-F models w/ GELU (I will likely deprecate/remove these and keep just the dm_ versions for GELU)
     nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
     nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
     nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
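The F-series depth tuples in `model_cfgs` above follow a simple pattern: the F0 stage depths scaled by (variant index + 1). A quick check against the tuples in the patch:

```python
# Reproduce the dm_nfnet_f0..f6 stage depths from the F0 base depths.
base = (1, 2, 6, 3)  # NFNet-F0 stage depths
for i in range(7):
    print(f'dm_nfnet_f{i}: depths={tuple(d * (i + 1) for d in base)}')
# dm_nfnet_f0: depths=(1, 2, 6, 3) ... dm_nfnet_f6: depths=(7, 14, 42, 21)
```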
@@ -229,7 +275,7 @@ class GammaAct(nn.Module):
         self.inplace = inplace
 
     def forward(self, x):
-        return self.gamma * self.act_fn(x, inplace=self.inplace)
+        return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
 
 
 def act_with_gamma(act_type, gamma: float = 1.):
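For context on `GammaAct` / `act_with_gamma`: the gamma constants are chosen so that the activation output has roughly unit variance for unit-normal input. A quick empirical estimate for GELU (the exact `_nonlin_gamma` table values come from the paper / reference impl, not from this snippet):

```python
import torch
import torch.nn.functional as F

# Estimate gamma = 1 / std(act(x)) for x ~ N(0, 1); for GELU this lands near the
# ~1.7015 constant used in the _nonlin_gamma lookup.
x = torch.randn(10_000_000)
print(f'estimated GELU gamma: {(1.0 / F.gelu(x).std()).item():.4f}')  # ~1.70
```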
@@ -325,8 +371,7 @@ class NormFreeBlock(nn.Module):
         out = self.drop_path(out)
 
         if self.skipinit_gain is not None:
-            # this really slows things down for some reason, TBD
-            out = out * self.skipinit_gain
+            out.mul_(self.skipinit_gain)  # this slows things down more than expected, TBD
         out = out * self.alpha + shortcut
         return out
@@ -419,12 +464,13 @@ class NormFreeNet(nn.Module):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
 
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
+        conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
         if cfg.gamma_in_act:
             act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
-            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
+            conv_layer = partial(conv_layer, eps=1e-4)  # DM weights better with higher eps
         else:
             act_layer = get_act_layer(cfg.act_layer)
-            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
 
         stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
@@ -538,6 +584,69 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
         **kwargs)
 
 
+@register_model
+def dm_nfnet_f0(pretrained=False, **kwargs):
+    """ NFNet-F0 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f1(pretrained=False, **kwargs):
+    """ NFNet-F1 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f2(pretrained=False, **kwargs):
+    """ NFNet-F2 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f3(pretrained=False, **kwargs):
+    """ NFNet-F3 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f4(pretrained=False, **kwargs):
+    """ NFNet-F4 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f5(pretrained=False, **kwargs):
+    """ NFNet-F5 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f6(pretrained=False, **kwargs):
+    """ NFNet-F6 (DeepMind weight compatible)
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+    """
+    return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def nfnet_f0(pretrained=False, **kwargs):
     """ NFNet-F0