Add ECA-NFNet-L0 weights and update model name. Update README and bump version to 0.4.6

pull/510/head
Ross Wightman 3 years ago
parent 5e2e4e7fb6
commit 740f32c96a

@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### March 17, 2021
* Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
* 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
* Uses SiLU activation, approx 2x faster than `dm_nfnet_f0` and 50% faster than `nfnet_f0s` w/ 1/3 param count
* Integrate [Hugging Face model hub](https://huggingface.co/models) into timm create_model and default_cfg handling for pretrained weight and config sharing (more on this soon!)
* Merge HardCoRe NAS models contributed by https://github.com/yoniaflalo
* Merge PyTorch trained EfficientNet-EL and pruned ES/EL variants contributed by [DeGirum](https://github.com/DeGirum)
### March 7, 2021
* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
@ -171,6 +180,7 @@ A full version of the list below with source links can be found in the [document
* MobileNet-V2 - https://arxiv.org/abs/1801.04381
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646
* HRNet - https://arxiv.org/abs/1908.07919
* Inception-V3 - https://arxiv.org/abs/1512.00567
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261

@ -76,13 +76,13 @@ class ScaledStdConv2d(nn.Conv2d):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
@ -110,12 +110,12 @@ class ScaledStdConv2dSame(nn.Conv2d):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
self.scale = gamma * self.weight[0].numel() ** -0.5
self.same_pad = is_dynamic
self.eps = eps ** 2 if use_layernorm else eps

@ -104,8 +104,8 @@ default_cfgs = dict(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
nfnet_l0b=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
nfnet_l0c=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0c-ad1045c2.pth',
eca_nfnet_l0=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth',
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
nf_regnet_b0=_dcfg(
@ -238,7 +238,7 @@ model_cfgs = dict(
nfnet_l0b=_nfnet_cfg(
depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
nfnet_l0c=_nfnet_cfg(
eca_nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
@ -343,7 +343,7 @@ class NormFreeBlock(nn.Module):
else:
self.attn = None
self.act3 = act_layer()
self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
if not reg and attn_layer is not None:
self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3
else:
@ -804,11 +804,11 @@ def nfnet_l0b(pretrained=False, **kwargs):
@register_model
def nfnet_l0c(pretrained=False, **kwargs):
""" NFNet-L0c w/ SiLU
def eca_nfnet_l0(pretrained=False, **kwargs):
""" ECA-NFNet-L0 w/ SiLU
My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('nfnet_l0c', pretrained=pretrained, **kwargs)
return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs)
@register_model

@ -1 +1 @@
__version__ = '0.4.5'
__version__ = '0.4.6'

Loading…
Cancel
Save