Fix issue with nfnet tests, bit more cleanup.

pull/427/head
Ross Wightman 4 years ago
parent cb06c7a910
commit 0d253e2c5e

--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*']
 # exclude models that cause specific test failures
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
+    EXCLUDE_FILTERS = [
+        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
+        'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
 else:
     EXCLUDE_FILTERS = NON_STD_FILTERS
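
Note: the added 'nfnet_f4*' .. 'nfnet_f7*' patterns keep the four largest NFNet variants out of the memory-constrained GitHub Actions runs. A minimal sketch (illustrative only, not the actual test harness) of how such fnmatch-style filters apply to model names:

    import fnmatch

    EXCLUDE_FILTERS = ['*efficientnet_l2*', 'nfnet_f4*', 'nfnet_f5*']
    model_names = ['nfnet_f0', 'nfnet_f4', 'nfnet_f4s', 'tf_efficientnet_l2_ns']

    included = [
        name for name in model_names
        if not any(fnmatch.fnmatch(name, pat) for pat in EXCLUDE_FILTERS)]
    print(included)  # ['nfnet_f0']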

--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -3,7 +3,6 @@
 Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
     - https://arxiv.org/abs/2101.08692
 Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
     - https://arxiv.org/abs/2102.06171
@@ -11,8 +10,8 @@ Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/m
 Status:
 * These models are a work in progress, experiments ongoing.
-* Two pretrained weights so far, more to come.
-* Model details update to closer match official JAX code now that it's released
+* Pretrained weights for two models so far, more to come.
+* Model details updated to closer match official JAX code now that it's released
 * NF-ResNet, NF-RegNet-B, and NFNet-F models supported

 Hacked together by / copyright Ross Wightman, 2021.
@@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
     num_features = channels[-1] * 2
     attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True,
+        depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True,
         num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
     return cfg
@@ -176,9 +175,6 @@ model_cfgs = dict(
     nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
     nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),

-    # NFNet-F models w/ SiLU (much faster in PyTorch)
-    # FIXME add remainder if silu vs gelu proves worthwhile
-
     # EffNet influenced RegNet defs.
     # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
     nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
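
Note: the NFNet-F stage depths follow a simple pattern: the F0 base depths (1, 2, 6, 3) scaled by the variant index plus one, which matches the f6s/f7s entries above. A quick check:

    # NFNet-F depths are (1, 2, 6, 3) scaled linearly with the variant index.
    base = (1, 2, 6, 3)
    for i in range(8):  # F0 .. F7
        print(f'nfnet_f{i}: depths={tuple(d * (i + 1) for d in base)}')
    # ...
    # nfnet_f6: depths=(7, 14, 42, 21)
    # nfnet_f7: depths=(8, 16, 48, 24)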
@@ -194,9 +190,9 @@ model_cfgs = dict(
     nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
     nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),

-    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),

    nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
    nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
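
Note: reduction_ratio=1/16 brings the NF-SE-ResNets in line with the standard SE-ResNet reduction of 16 (the earlier 0.25 made the SE bottleneck 4x wider). A minimal sketch of a squeeze-and-excitation block, showing roughly how the ratio sets the bottleneck width (timm's actual SEModule also rounds channels with a divisor):

    import torch
    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        def __init__(self, channels, reduction_ratio=1/16):
            super().__init__()
            rd_channels = max(1, int(channels * reduction_ratio))
            self.fc1 = nn.Conv2d(channels, rd_channels, 1)
            self.act = nn.ReLU(inplace=True)
            self.fc2 = nn.Conv2d(rd_channels, channels, 1)
            self.gate = nn.Sigmoid()

        def forward(self, x):
            s = x.mean((2, 3), keepdim=True)     # squeeze: global avg pool
            s = self.fc2(self.act(self.fc1(s)))  # excite: bottleneck MLP
            return x * self.gate(s)              # channel re-weighting

    # With 256 channels: ratio 0.25 -> 64 hidden chs, ratio 1/16 -> 16.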
@@ -315,38 +311,26 @@ class NormFreeBlock(nn.Module):
         return out

-def stem_info(stem_type):
-    stem_stride = 2
-    if 'nff' in stem_type or 'pool' in stem_type:
-        stem_stride = 4
-    stem_feat = ''
-    if 'nff' in stem_type:
-        stem_feat = 'stem.act3'
-    elif 'deep' in stem_type and not 'pool' in stem_type:
-        stem_feat = 'stem.act2'
-    return stem_stride, stem_feat
-
 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
     stem_stride = 2
-    stem_feature = ''
+    stem_feature = dict(num_chs=out_chs, reduction=2, module='')
     stem = OrderedDict()
-    assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
+    assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type or 'nff' in stem_type:
         # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
-        if 'nff' in stem_type:
+        if 'quad' in stem_type:
             assert not 'pool' in stem_type
             stem_chs = (16, 32, 64, out_chs)
             strides = (2, 1, 1, 2)
             stem_stride = 4
-            stem_feature = 'stem.act4'
+            stem_feature = dict(num_chs=64, reduction=2, module='stem.act4')
         else:
             if 'tiered' in stem_type:
-                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)
+                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)  # like 'T' resnets in resnet.py
             else:
-                stem_chs = (out_chs // 2, out_chs // 2, out_chs)
+                stem_chs = (out_chs // 2, out_chs // 2, out_chs)  # 'D' ResNets
             strides = (2, 1, 1)
-            stem_feature = 'stem.act3'
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
         last_idx = len(stem_chs) - 1
         for i, (c, s) in enumerate(zip(stem_chs, strides)):
             stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
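
Note: with 'deep_quad', the stem is four 3x3 convs with strides (2, 1, 1, 2), and the feature tap 'stem.act4' sits after conv3 (64 channels, still at reduction 2) while the stem output is at reduction 4. A shape-check sketch, with plain Conv2d/GELU standing in for timm's ScaledStdConv and gamma-scaled activations:

    import torch
    import torch.nn as nn

    out_chs = 128
    stem = nn.Sequential(
        nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.GELU(),   # conv1, act2
        nn.Conv2d(16, 32, 3, stride=1, padding=1), nn.GELU(),  # conv2, act3
        nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.GELU(),  # conv3, act4 <- feature tap
        nn.Conv2d(64, out_chs, 3, stride=2, padding=1),        # conv4
    )
    x = torch.randn(1, 3, 224, 224)
    print(stem(x).shape)  # torch.Size([1, 128, 56, 56]) -> overall stride 4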
@@ -401,7 +385,7 @@ class NormFreeNet(nn.Module):
     * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
         impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
     * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
-        apply it in each activation. This is slightly slower, and yields slightly different results.
+        apply it in each activation. This is slightly slower, numerically different, but matches official impl.
     * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
         for what it is/does. Approx 8-10% throughput loss.
     """
@@ -424,7 +408,7 @@ class NormFreeNet(nn.Module):
         self.stem, stem_stride, stem_feat = create_stem(
             in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
-        self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else []
+        self.feature_info = [stem_feat] if stem_stride == 4 else []

         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         prev_chs = stem_chs
         net_stride = stem_stride
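
Note: the drop_path_rates line above builds the stochastic-depth schedule: a linear ramp from 0 to drop_path_rate over all blocks, split back into per-stage lists. For example:

    import torch

    depths = (1, 2, 6, 3)  # e.g. NFNet-F0
    drop_path_rate = 0.1
    rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    print(rates[0])   # [0.0] -- the first block drops nothing
    print(rates[-1])  # last stage gets the highest rates, up to 0.1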
@@ -476,7 +460,6 @@ class NormFreeNet(nn.Module):
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
             self.final_conv = conv_layer(prev_chs, self.num_features, 1)
-            # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
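
Note: make_divisible rounds the head width to a multiple of cfg.ch_div. A sketch of the usual EfficientNet/timm-style rounding rule (assumed here; see timm's layer helpers for the exact version):

    def make_divisible(v, divisor=8, min_value=None):
        # Round to the nearest multiple of divisor, but never drop below
        # 90% of the requested value.
        min_value = min_value or divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    print(make_divisible(1280 * 1.5, 8))  # 1920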
@@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f4(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f5(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
@@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f7(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
@@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f4s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f5s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
@@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)

+
 @register_model
 def nfnet_f7s(pretrained=False, **kwargs):
     return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
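
Note: all of the registered variants resolve through timm's model factory. A usage sketch (pretrained=False, since weights for most NFNet-F variants weren't yet published at this commit):

    import timm
    import torch

    model = timm.create_model('nfnet_f0s', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])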
