@@ -3,7 +3,6 @@
 Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
     - https://arxiv.org/abs/2101.08692
 
 Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
     - https://arxiv.org/abs/2102.06171
 
@@ -11,8 +10,8 @@ Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/m
 Status:
 * These models are a work in progress, experiments ongoing.
-* Two pretrained weights so far, more to come.
-* Model details update to closer match official JAX code now that it's released
+* Pretrained weights for two models so far, more to come.
+* Model details updated to closer match official JAX code now that it's released
 * NF-ResNet, NF-RegNet-B, and NFNet-F models supported
 
 Hacked together by / copyright Ross Wightman, 2021.
@@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
     num_features = channels[-1] * 2
     attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True,
+        depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True,
         num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
     return cfg
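For orientation: with this rename, every NFNet-F config resolves to the four-conv stem under its new name. A hedged usage sketch (the depths below are the paper's F0 depths, and the truncated signature in the hunk header elides a channels default):

    cfg = _nfnet_cfg(depths=(1, 2, 6, 3))  # hypothetical F0-style config
    assert cfg.stem_type == 'deep_quad'    # was 'nff' before this change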
@@ -176,9 +175,6 @@ model_cfgs = dict(
     nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
     nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
 
-    # NFNet-F models w/ SiLU (much faster in PyTorch)
-    # FIXME add remainder if silu vs gelu proves worthwhile
-
     # EffNet influenced RegNet defs.
     # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
     nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
@@ -194,9 +190,9 @@ model_cfgs = dict(
     nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
     nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
 
-    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
 
     nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
     nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
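Note on the change above: reduction_ratio moves from 0.25 to 1/16, matching the classic SE-ResNet bottleneck ratio (r=16). A minimal, self-contained sketch of a squeeze-and-excitation block with this parameter, simplified from, and not identical to, the attention layer these configs actually reference:

    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        # Minimal SE sketch: reduction_ratio sets the hidden width of the
        # channel-attention MLP, e.g. 1/16 of 256 input chs -> 16 hidden chs.
        def __init__(self, chs, reduction_ratio=1/16):
            super().__init__()
            rd_chs = max(1, int(chs * reduction_ratio))
            self.fc1 = nn.Conv2d(chs, rd_chs, kernel_size=1)
            self.act = nn.ReLU(inplace=True)
            self.fc2 = nn.Conv2d(rd_chs, chs, kernel_size=1)
            self.gate = nn.Sigmoid()

        def forward(self, x):
            s = x.mean((2, 3), keepdim=True)  # squeeze: global average pool
            s = self.gate(self.fc2(self.act(self.fc1(s))))  # excite: channel gate
            return x * s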
@@ -315,38 +311,26 @@ class NormFreeBlock(nn.Module):
         return out
 
 
-def stem_info(stem_type):
-    stem_stride = 2
-    if 'nff' in stem_type or 'pool' in stem_type:
-        stem_stride = 4
-    stem_feat = ''
-    if 'nff' in stem_type:
-        stem_feat = 'stem.act3'
-    elif 'deep' in stem_type and not 'pool' in stem_type:
-        stem_feat = 'stem.act2'
-    return stem_stride, stem_feat
-
-
 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
     stem_stride = 2
-    stem_feature = ''
+    stem_feature = dict(num_chs=out_chs, reduction=2, module='')
     stem = OrderedDict()
-    assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
+    assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type or 'nff' in stem_type:
         # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
-        if 'nff' in stem_type:
+        if 'quad' in stem_type:
             assert not 'pool' in stem_type
             stem_chs = (16, 32, 64, out_chs)
             strides = (2, 1, 1, 2)
             stem_stride = 4
-            stem_feature = 'stem.act4'
+            stem_feature = dict(num_chs=64, reduction=2, module='stem.act4')
         else:
             if 'tiered' in stem_type:
-                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)
+                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)  # like 'T' resnets in resnet.py
             else:
-                stem_chs = (out_chs // 2, out_chs // 2, out_chs)
+                stem_chs = (out_chs // 2, out_chs // 2, out_chs)  # 'D' ResNets
             strides = (2, 1, 1)
-            stem_feature = 'stem.act3'
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
     last_idx = len(stem_chs) - 1
     for i, (c, s) in enumerate(zip(stem_chs, strides)):
         stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
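To make the new 'quad' branch concrete, here is a standalone sketch of the stem it builds (plain nn.Conv2d/nn.GELU stand in for the conv_layer/act_layer arguments the real code is given; the act naming mirrors the loop above, which is why the stride-4 feature hook lands on 'stem.act4' at 64 channels):

    from collections import OrderedDict
    import torch
    import torch.nn as nn

    def quad_stem(in_chs=3, out_chs=128):
        # 'deep_quad' stem: 4 convs at strides (2, 1, 1, 2), overall stride 4,
        # with activations between convs only (named act2..act4).
        stem_chs, strides = (16, 32, 64, out_chs), (2, 1, 1, 2)
        stem = OrderedDict()
        last_idx = len(stem_chs) - 1
        for i, (c, s) in enumerate(zip(stem_chs, strides)):
            stem[f'conv{i + 1}'] = nn.Conv2d(in_chs, c, 3, stride=s, padding=1)
            if i != last_idx:
                stem[f'act{i + 2}'] = nn.GELU()  # 'stem.act4' follows the 64-ch conv3
            in_chs = c
        return nn.Sequential(stem)

    print(quad_stem()(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 128, 56, 56])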
@@ -401,7 +385,7 @@ class NormFreeNet(nn.Module):
     * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
       impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
     * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
-      apply it in each activation. This is slightly slower, and yields slightly different results.
+      apply it in each activation. This is slightly slower, numerically different, but matches official impl.
     * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
       for what it is/does. Approx 8-10% throughput loss.
     """
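The two placements described in that docstring are the same scaling applied in different spots. A rough, hypothetical sketch of both (the weight standardization here is simplified and not the module's exact ScaledStdConv2d):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ScaledStdConv2dSketch(nn.Conv2d):
        # gamma folded into the conv: the activation's variance-correcting
        # constant scales the standardized weight once, at conv time.
        def __init__(self, in_chs, out_chs, kernel_size, gamma=1.0, eps=1e-5, **kwargs):
            super().__init__(in_chs, out_chs, kernel_size, **kwargs)
            self.gain = nn.Parameter(torch.ones(out_chs, 1, 1, 1))
            self.gamma, self.eps = gamma, eps

        def forward(self, x):
            std, mean = torch.std_mean(self.weight, dim=(1, 2, 3), keepdim=True)
            weight = self.gamma * self.gain * (self.weight - mean) / (std + self.eps)
            return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                            self.dilation, self.groups)

    # The gamma_in_act alternative: leave the conv unscaled (gamma=1.0) and
    # scale each activation instead, e.g. GELU with its correcting constant:
    def gamma_gelu(x, gamma=1.7015043497085571):  # GELU gamma per the NFNet recipe
        return gamma * F.gelu(x)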
@@ -424,7 +408,7 @@ class NormFreeNet(nn.Module):
         self.stem, stem_stride, stem_feat = create_stem(
             in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
 
-        self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else []
+        self.feature_info = [stem_feat] if stem_stride == 4 else []
         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         prev_chs = stem_chs
         net_stride = stem_stride
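Beyond tidying the call site, this appears to correct the recorded channel count: the old dict reported the stem's output width (stem_chs) for the hooked module, while the dict now built inside create_stem reports the width at the hook itself, e.g. for the quad stem:

    # Entry produced inside create_stem for 'deep_quad' (restated from above):
    # the hook at 'stem.act4' sees 64 channels, not the stem's final out_chs.
    stem_feat = dict(num_chs=64, reduction=2, module='stem.act4')
    stem_stride = 4
    feature_info = [stem_feat] if stem_stride == 4 else []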
@@ -476,7 +460,6 @@ class NormFreeNet(nn.Module):
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
             self.final_conv = conv_layer(prev_chs, self.num_features, 1)
-            # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
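For context on the comment that survives this hunk, a hedged sketch of the EfficientNet-like head pattern it refers to (channel counts are illustrative, not the model's):

    import torch.nn as nn

    # A 1x1 conv expands the last block's channels to a wider num_features
    # before global pooling and the classifier, EfficientNet-style.
    prev_chs, num_features, num_classes = 960, 1280, 1000  # hypothetical widths
    head = nn.Sequential(
        nn.Conv2d(prev_chs, num_features, kernel_size=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(num_features, num_classes),
    )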
@@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f4(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f5(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
 
@@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f7(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
 
@@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f4s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f5s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
 
@@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
 
+
 @register_model
 def nfnet_f7s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
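All of the registrations above flow through timm's model factory, so each variant resolves by name once the module is imported. A quick usage sketch (most variants have no published weights yet, hence pretrained=False):

    import timm
    import torch

    model = timm.create_model('nfnet_f3s', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])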