From 0d253e2c5e29814ad6fd58be02a3dd154d9debc6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Feb 2021 21:05:41 -0800 Subject: [PATCH] Fix issue with nfnet tests, bit more cleanup. --- tests/test_models.py | 4 +++- timm/models/nfnet.py | 53 ++++++++++++++++++-------------------------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3f1c4cda..407e0fe5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*'] # exclude models that cause specific test failures if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models - EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS + EXCLUDE_FILTERS = [ + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', + 'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index b9c003e8..4dc848ba 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -3,7 +3,6 @@ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 - Paper: `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 @@ -11,8 +10,8 @@ Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/m Status: * These models are a work in progress, experiments ongoing. -* Two pretrained weights so far, more to come. -* Model details update to closer match official JAX code now that it's released +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released * NF-ResNet, NF-RegNet-B, and NFNet-F models supported Hacked together by / copyright Ross Wightman, 2021. @@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None): num_features = channels[-1] * 2 attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8) cfg = NfCfg( - depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True, + depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True, num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) return cfg @@ -176,9 +175,6 @@ model_cfgs = dict( nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), - # NFNet-F models w/ SiLU (much faster in PyTorch) - # FIXME add remainder if silu vs gelu proves worthwhile - # EffNet influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), @@ -194,9 +190,9 @@ model_cfgs = dict( nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), - nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), @@ -315,38 +311,26 @@ class NormFreeBlock(nn.Module): return out -def stem_info(stem_type): - stem_stride = 2 - if 'nff' in stem_type or 'pool' in stem_type: - stem_stride = 4 - stem_feat = '' - if 'nff' in stem_type: - stem_feat = 'stem.act3' - elif 'deep' in stem_type and not 'pool' in stem_type: - stem_feat = 'stem.act2' - return stem_stride, stem_feat - - def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_stride = 2 - stem_feature = '' + stem_feature = dict(num_chs=out_chs, reduction=2, module='') stem = OrderedDict() - assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') if 'deep' in stem_type or 'nff' in stem_type: # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here - if 'nff' in stem_type: + if 'quad' in stem_type: assert not 'pool' in stem_type stem_chs = (16, 32, 64, out_chs) strides = (2, 1, 1, 2) stem_stride = 4 - stem_feature = 'stem.act4' + stem_feature = dict(num_chs=64, reduction=2, module='stem.act4') else: if 'tiered' in stem_type: - stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # like 'T' resnets in resnet.py else: - stem_chs = (out_chs // 2, out_chs // 2, out_chs) + stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets strides = (2, 1, 1) - stem_feature = 'stem.act3' + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3') last_idx = len(stem_chs) - 1 for i, (c, s) in enumerate(zip(stem_chs, strides)): stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) @@ -401,7 +385,7 @@ class NormFreeNet(nn.Module): * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl. * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but - apply it in each activation. This is slightly slower, and yields slightly different results. + apply it in each activation. This is slightly slower, numerically different, but matches official impl. * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does. Approx 8-10% throughput loss. """ @@ -424,7 +408,7 @@ class NormFreeNet(nn.Module): self.stem, stem_stride, stem_feat = create_stem( in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) - self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else [] + self.feature_info = [stem_feat] if stem_stride == 4 else [] drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] prev_chs = stem_chs net_stride = stem_stride @@ -476,7 +460,6 @@ class NormFreeNet(nn.Module): # The paper NFRegNet models have an EfficientNet-like final head convolution. self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) self.final_conv = conv_layer(prev_chs, self.num_features, 1) - # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv else: self.num_features = prev_chs self.final_conv = nn.Identity() @@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) +@register_model def nfnet_f4(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) +@register_model def nfnet_f5(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) @@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) +@register_model def nfnet_f7(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) @@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f4s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f5s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) @@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f7s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)