From e4de077021baa9bd6b0ba120a83030422c237561 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 11 Feb 2021 13:15:20 -0800 Subject: [PATCH 01/10] Add first 'Normalizer Free' models. nf_regnet_b1 79.3 @ 288x288 test, and nf_resnet50 80.3 @ 256x256 test (80.68 @ 288x288). --- README.md | 3 +++ timm/models/nfnet.py | 10 +++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a5b4b536..572109de 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,9 @@ ## What's New ### Feb 10, 2021 +* First Normalizer-Free model training experiments done, + * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256 + * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256 * More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index c56c5780..7b79259c 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -34,17 +34,21 @@ def _dcfg(url='', **kwargs): **kwargs } -# FIXME finish + default_cfgs = { 'nf_regnet_b0': _dcfg(url=''), - 'nf_regnet_b1': _dcfg(url='', input_size=(3, 240, 240), pool_size=(8, 8)), + 'nf_regnet_b1': _dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9), 'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)), 'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)), 'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)), 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_resnet50': _dcfg(url='', first_conv='stem.conv'), + 'nf_resnet50': _dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', + first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94), 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'), 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'), From 607f9149b185780150b11818a3371ce81609837b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 11 Feb 2021 22:06:06 -0800 Subject: [PATCH 02/10] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 572109de..81fe904b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## What's New ### Feb 10, 2021 -* First Normalizer-Free model training experiments done, +* First Normalization-Free model training experiments done, * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256 * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256 * More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') From cb06c7a910cb9b1078679bc67c76afcbb7453d3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Feb 2021 18:28:56 -0800 Subject: [PATCH 03/10] Add NFNet-F models and tweak existing NF models. --- timm/models/nfnet.py | 475 ++++++++++++++++++++++++++++++++----------- 1 file changed, 354 insertions(+), 121 deletions(-) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 7b79259c..b9c003e8 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -1,10 +1,19 @@ -""" Normalizer Free RegNet / ResNet (pre-activation) Models +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 -NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some. -Details may change, especially once the paper authors release their official models. + +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Two pretrained weights so far, more to come. +* Model details update to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported Hacked together by / copyright Ross Wightman, 2021. """ @@ -28,37 +37,71 @@ def _dcfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', **kwargs } -default_cfgs = { - 'nf_regnet_b0': _dcfg(url=''), - 'nf_regnet_b1': _dcfg( +default_cfgs = dict( + nfnet_f0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nfnet_f0s=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1s=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2s=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3s=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4s=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5s=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6s=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7s=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nf_regnet_b1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', - pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9), - 'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)), - 'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)), - 'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)), - - 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_resnet50': _dcfg( + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec + nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)), + nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)), + nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)), + nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)), + + nf_resnet26=_dcfg(url=''), + nf_resnet50=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', - first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94), - 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'), + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94), + nf_resnet101=_dcfg(url=''), - 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'), + nf_seresnet26=_dcfg(url=''), + nf_seresnet50=_dcfg(url=''), + nf_seresnet101=_dcfg(url=''), - 'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'), -} + nf_ecaresnet26=_dcfg(url=''), + nf_ecaresnet50=_dcfg(url=''), + nf_ecaresnet101=_dcfg(url=''), +) @dataclass @@ -69,69 +112,95 @@ class NfCfg: gamma_in_act: bool = False stem_type: str = '3x3' stem_chs: Optional[int] = None - group_size: Optional[int] = 8 - attn_layer: Optional[str] = 'se' - attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8)) + group_size: Optional[int] = None + attn_layer: Optional[str] = None + attn_kwargs: dict = None attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used - width_factor: float = 0.75 - bottle_ratio: float = 2.25 - efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models - num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode) + width_factor: float = 1.0 + bottle_ratio: float = 0.5 + num_features: int = 0 # num out_channels for final conv, no final_conv if 0 ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal - skipinit: bool = False + reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle + extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + skipinit: bool = False # disabled by default, non-trivial performance impact + zero_init_fc: bool = False act_layer: str = 'silu' +def _nfres_cfg( + depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + attn_kwargs = attn_kwargs or {} + cfg = NfCfg( + depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, + group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + +def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): + num_features = 1280 * channels[-1] // 440 + attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, + num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + return cfg + + +def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None): + channels = (256, 512, 1536, 1536) + num_features = channels[-1] * 2 + attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True, + num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + model_cfgs = dict( - # EffNet influenced RegNet defs - nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416), - nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536), - nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792), - nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048), + # NFNet-F models w/ GeLU + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), + nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), + nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), + nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), + nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), + nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), + nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), + nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + # FIXME add remainder if silu vs gelu proves worthwhile + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. + nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + # FIXME add B6-B8 # ResNet (preact, D style deep stem/avg down) defs - nf_resnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None,), - nf_resnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - nf_resnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - - - nf_seresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - - - nf_ecaresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()), ) @@ -170,20 +239,20 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -class NormalizationFreeBlock(nn.Module): +class NormFreeBlock(nn.Module): """Normalization-free pre-activation block. """ def __init__( self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, - alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None, - attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False): + alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, + skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs - # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet - mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div) - groups = 1 if group_size is None else mid_chs // group_size + # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div) + groups = 1 if not group_size else mid_chs // group_size if group_size and group_size % ch_div == 0: mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error self.alpha = alpha @@ -200,12 +269,22 @@ class NormalizationFreeBlock(nn.Module): self.conv1 = conv_layer(in_chs, mid_chs, 1) self.act2 = act_layer(inplace=True) self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - if attn_layer is not None: - self.attn = attn_layer(mid_chs) + if extra_conv: + self.act2b = act_layer(inplace=True) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + else: + self.act2b = None + self.conv2b = None + if reg and attn_layer is not None: + self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 else: self.attn = None self.act3 = act_layer() self.conv3 = conv_layer(mid_chs, out_chs, 1) + if not reg and attn_layer is not None: + self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + else: + self.attn_last = None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None @@ -220,28 +299,60 @@ class NormalizationFreeBlock(nn.Module): # residual branch out = self.conv1(out) out = self.conv2(self.act2(out)) + if self.conv2b is not None: + out = self.conv2b(self.act2b(out)) if self.attn is not None: out = self.attn_gain * self.attn(out) out = self.conv3(self.act3(out)) + if self.attn_last is not None: + out = self.attn_gain * self.attn_last(out) out = self.drop_path(out) - if self.skipinit_gain is None: - out = out * self.alpha + shortcut - else: + + if self.skipinit_gain is not None: # this really slows things down for some reason, TBD - out = out * self.alpha * self.skipinit_gain + shortcut + out = out * self.skipinit_gain + out = out * self.alpha + shortcut return out -def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): +def stem_info(stem_type): + stem_stride = 2 + if 'nff' in stem_type or 'pool' in stem_type: + stem_stride = 4 + stem_feat = '' + if 'nff' in stem_type: + stem_feat = 'stem.act3' + elif 'deep' in stem_type and not 'pool' in stem_type: + stem_feat = 'stem.act2' + return stem_stride, stem_feat + + +def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_stride = 2 + stem_feature = '' stem = OrderedDict() - assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') - if 'deep' in stem_type: + assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + if 'deep' in stem_type or 'nff' in stem_type: # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here - mid_chs = out_chs // 2 - stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) - stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) - stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if 'nff' in stem_type: + assert not 'pool' in stem_type + stem_chs = (16, 32, 64, out_chs) + strides = (2, 1, 1, 2) + stem_stride = 4 + stem_feature = 'stem.act4' + else: + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) + else: + stem_chs = (out_chs // 2, out_chs // 2, out_chs) + strides = (2, 1, 1) + stem_feature = 'stem.act3' + last_idx = len(stem_chs) - 1 + for i, (c, s) in enumerate(zip(stem_chs, strides)): + stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + if i != last_idx: + stem[f'act{i+2}'] = act_layer(inplace=True) + in_chs = c elif '3x3' in stem_type: # 3x3 stem conv as in RegNet stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) @@ -253,18 +364,30 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) stem_stride = 4 - return nn.Sequential(stem), stem_stride + return nn.Sequential(stem), stem_stride, stem_feature _nonlin_gamma = dict( - silu=1./.5595, - relu=(0.5 * (1. - 1. / math.pi)) ** -0.5, - identity=1.0 + identity=1.0, + celu=1.270926833152771, + elu=1.2716004848480225, + gelu=1.7015043497085571, + leaky_relu=1.70590341091156, + log_sigmoid=1.9193484783172607, + log_softmax=1.0002083778381348, + relu=1.7139588594436646, + relu6=1.7131484746932983, + selu=1.0008515119552612, + sigmoid=4.803835391998291, + silu=1.7881293296813965, + softsign=2.338853120803833, + softplus=1.9203323125839233, + tanh=1.5939117670059204, ) -class NormalizerFreeNet(nn.Module): - """ Normalizer-free ResNets and RegNets +class NormFreeNet(nn.Module): + """ Normalization-free ResNets and RegNets As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 @@ -298,10 +421,11 @@ class NormalizerFreeNet(nn.Module): stem_chs = cfg.stem_chs or cfg.channels[0] stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div) - self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer) + self.stem, stem_stride, stem_feat = create_stem( + in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) - self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else [] + drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] prev_chs = stem_chs net_stride = stem_stride dilation = 1 @@ -309,8 +433,8 @@ class NormalizerFreeNet(nn.Module): stages = [] for stage_idx, stage_depth in enumerate(cfg.depths): stride = 1 if stage_idx == 0 and stem_stride > 2 else 2 - self.feature_info += [dict( - num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')] + if stride == 2: + self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')] if net_stride >= output_stride and stride > 1: dilation *= stride stride = 1 @@ -321,7 +445,7 @@ class NormalizerFreeNet(nn.Module): for block_idx in range(cfg.depths[stage_idx]): first_block = block_idx == 0 and stage_idx == 0 out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div) - blocks += [NormalizationFreeBlock( + blocks += [NormFreeBlock( in_chs=prev_chs, out_chs=out_chs, alpha=cfg.alpha, beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block @@ -329,15 +453,16 @@ class NormalizerFreeNet(nn.Module): dilation=dilation, first_dilation=first_dilation, group_size=cfg.group_size, - bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio, - efficient=cfg.efficient, + bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio, ch_div=cfg.ch_div, + reg=cfg.reg, + extra_conv=cfg.extra_conv, + skipinit=cfg.skipinit, attn_layer=attn_layer, attn_gain=cfg.attn_gain, act_layer=act_layer, conv_layer=conv_layer, - drop_path_rate=dpr[stage_idx][block_idx], - skipinit=cfg.skipinit, + drop_path_rate=drop_path_rates[stage_idx][block_idx], )] if block_idx == 0: expected_var = 1. # expected var is reset after first block of each stage @@ -347,22 +472,25 @@ class NormalizerFreeNet(nn.Module): stages += [nn.Sequential(*blocks)] self.stages = nn.Sequential(*stages) - if cfg.efficient and cfg.num_features: + if cfg.num_features: # The paper NFRegNet models have an EfficientNet-like final head convolution. self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) self.final_conv = conv_layer(prev_chs, self.num_features, 1) + # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv else: self.num_features = prev_chs self.final_conv = nn.Identity() - # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv - self.final_act = act_layer() + self.final_act = act_layer(inplace=cfg.num_features > 0) self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')] self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) for n, m in self.named_modules(): if 'fc' in n and isinstance(m, nn.Linear): - nn.init.zeros_(m.weight) + if cfg.zero_init_fc: + nn.init.zeros_(m.weight) + else: + nn.init.normal_(m.weight, 0., .01) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): @@ -395,17 +523,121 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): model_cfg = model_cfgs[variant] feature_cfg = dict(flatten_sequential=True) feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks - if 'pool' in model_cfg.stem_type: - feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet + if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type: + feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems return build_model_with_cfg( - NormalizerFreeNet, variant, pretrained, + NormFreeNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfg, feature_cfg=feature_cfg, **kwargs) +@register_model +def nfnet_f0(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) + + +def nfnet_f4(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) + + +def nfnet_f5(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) + + +def nfnet_f7(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) + + +def nfnet_f4s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) + + +def nfnet_f5s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) + + +def nfnet_f7s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b0(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b1(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b2(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b3(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b4(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b5(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) @@ -475,6 +707,7 @@ def nf_ecaresnet26(pretrained=False, **kwargs): def nf_ecaresnet50(pretrained=False, **kwargs): return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) + @register_model def nf_ecaresnet101(pretrained=False, **kwargs): return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) \ No newline at end of file From 0d253e2c5e29814ad6fd58be02a3dd154d9debc6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Feb 2021 21:05:41 -0800 Subject: [PATCH 04/10] Fix issue with nfnet tests, bit more cleanup. --- tests/test_models.py | 4 +++- timm/models/nfnet.py | 53 ++++++++++++++++++-------------------------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3f1c4cda..407e0fe5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*'] # exclude models that cause specific test failures if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models - EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS + EXCLUDE_FILTERS = [ + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', + 'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index b9c003e8..4dc848ba 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -3,7 +3,6 @@ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 - Paper: `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 @@ -11,8 +10,8 @@ Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/m Status: * These models are a work in progress, experiments ongoing. -* Two pretrained weights so far, more to come. -* Model details update to closer match official JAX code now that it's released +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released * NF-ResNet, NF-RegNet-B, and NFNet-F models supported Hacked together by / copyright Ross Wightman, 2021. @@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None): num_features = channels[-1] * 2 attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8) cfg = NfCfg( - depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True, + depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True, num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) return cfg @@ -176,9 +175,6 @@ model_cfgs = dict( nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), - # NFNet-F models w/ SiLU (much faster in PyTorch) - # FIXME add remainder if silu vs gelu proves worthwhile - # EffNet influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), @@ -194,9 +190,9 @@ model_cfgs = dict( nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), - nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), @@ -315,38 +311,26 @@ class NormFreeBlock(nn.Module): return out -def stem_info(stem_type): - stem_stride = 2 - if 'nff' in stem_type or 'pool' in stem_type: - stem_stride = 4 - stem_feat = '' - if 'nff' in stem_type: - stem_feat = 'stem.act3' - elif 'deep' in stem_type and not 'pool' in stem_type: - stem_feat = 'stem.act2' - return stem_stride, stem_feat - - def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_stride = 2 - stem_feature = '' + stem_feature = dict(num_chs=out_chs, reduction=2, module='') stem = OrderedDict() - assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') if 'deep' in stem_type or 'nff' in stem_type: # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here - if 'nff' in stem_type: + if 'quad' in stem_type: assert not 'pool' in stem_type stem_chs = (16, 32, 64, out_chs) strides = (2, 1, 1, 2) stem_stride = 4 - stem_feature = 'stem.act4' + stem_feature = dict(num_chs=64, reduction=2, module='stem.act4') else: if 'tiered' in stem_type: - stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # like 'T' resnets in resnet.py else: - stem_chs = (out_chs // 2, out_chs // 2, out_chs) + stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets strides = (2, 1, 1) - stem_feature = 'stem.act3' + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3') last_idx = len(stem_chs) - 1 for i, (c, s) in enumerate(zip(stem_chs, strides)): stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) @@ -401,7 +385,7 @@ class NormFreeNet(nn.Module): * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl. * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but - apply it in each activation. This is slightly slower, and yields slightly different results. + apply it in each activation. This is slightly slower, numerically different, but matches official impl. * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does. Approx 8-10% throughput loss. """ @@ -424,7 +408,7 @@ class NormFreeNet(nn.Module): self.stem, stem_stride, stem_feat = create_stem( in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) - self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else [] + self.feature_info = [stem_feat] if stem_stride == 4 else [] drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] prev_chs = stem_chs net_stride = stem_stride @@ -476,7 +460,6 @@ class NormFreeNet(nn.Module): # The paper NFRegNet models have an EfficientNet-like final head convolution. self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) self.final_conv = conv_layer(prev_chs, self.num_features, 1) - # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv else: self.num_features = prev_chs self.final_conv = nn.Identity() @@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) +@register_model def nfnet_f4(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) +@register_model def nfnet_f5(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) @@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) +@register_model def nfnet_f7(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) @@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f4s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f5s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) @@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) +@register_model def nfnet_f7s(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) From d86dbe45c2b9a1b7a7c054b4687a526b8ec8df61 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Feb 2021 22:07:18 -0800 Subject: [PATCH 05/10] Update README.md and few more comments --- README.md | 6 +++++- timm/models/nfnet.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 572109de..c4f3a588 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,11 @@ ## What's New +### Feb 12, 2021 +* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs + ### Feb 10, 2021 -* First Normalizer-Free model training experiments done, +* First Normalization-Free model training experiments done, * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256 * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256 * More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') @@ -164,6 +167,7 @@ A full version of the list below with source links can be found in the [document * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * NASNet-A - https://arxiv.org/abs/1707.07012 +* NFNet-F - https://arxiv.org/abs/2102.06171 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692 * PNasNet - https://arxiv.org/abs/1712.00559 * RegNet - https://arxiv.org/abs/2003.13678 diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 4dc848ba..1f83f6df 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -236,7 +236,7 @@ class DownsampleAvg(nn.Module): class NormFreeBlock(nn.Module): - """Normalization-free pre-activation block. + """Normalization-Free pre-activation block. """ def __init__( @@ -351,6 +351,7 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): return nn.Sequential(stem), stem_stride, stem_feature +# from https://github.com/deepmind/deepmind-research/tree/master/nfnets _nonlin_gamma = dict( identity=1.0, celu=1.270926833152771, @@ -371,10 +372,13 @@ _nonlin_gamma = dict( class NormFreeNet(nn.Module): - """ Normalization-free ResNets and RegNets + """ Normalization-Free Network - As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + As described in : + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 + and + `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and the (preact) ResNet models described earlier in the paper. @@ -432,7 +436,7 @@ class NormFreeNet(nn.Module): blocks += [NormFreeBlock( in_chs=prev_chs, out_chs=out_chs, alpha=cfg.alpha, - beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block + beta=1. / expected_var ** 0.5, stride=stride if block_idx == 0 else 1, dilation=dilation, first_dilation=first_dilation, @@ -477,8 +481,6 @@ class NormFreeNet(nn.Module): if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): - # as per discussion with paper authors, original in haiku is - # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear') if m.bias is not None: nn.init.zeros_(m.bias) From 5f9aff395c224492e9e44248b15f44b5cc095d9c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 13 Feb 2021 16:58:51 -0800 Subject: [PATCH 06/10] Fix stem width in NFNet-F models, add some more comments, add some 'light' NFNet models for testing. --- timm/models/nfnet.py | 199 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 162 insertions(+), 37 deletions(-) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 1f83f6df..b43ee5ef 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -78,6 +78,13 @@ default_cfgs = dict( nfnet_f7s=_dcfg( url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + nfnet_l0a=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_l0b=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_l0c=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), nf_regnet_b1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', @@ -144,13 +151,15 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): return cfg -def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None): - channels = (256, 512, 1536, 1536) - num_features = channels[-1] * 2 - attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8) +def _nfnet_cfg( + depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., + act_layer='gelu', attn_layer='se', attn_kwargs=None): + num_features = int(channels[-1] * feat_mult) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8) cfg = NfCfg( - depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True, - num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, + bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, + attn_layer=attn_layer, attn_kwargs=attn_kwargs) return cfg @@ -175,6 +184,17 @@ model_cfgs = dict( nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + # Experimental 'light' versions of nfnet-f that are little leaner + nfnet_l0a=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1280, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + nfnet_l0b=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + nfnet_l0c=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + # EffNet influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), @@ -316,26 +336,26 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_feature = dict(num_chs=out_chs, reduction=2, module='') stem = OrderedDict() assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') - if 'deep' in stem_type or 'nff' in stem_type: - # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here + if 'deep' in stem_type: if 'quad' in stem_type: + # 4 deep conv stack as in NFNet-F models assert not 'pool' in stem_type - stem_chs = (16, 32, 64, out_chs) + stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs) strides = (2, 1, 1, 2) stem_stride = 4 - stem_feature = dict(num_chs=64, reduction=2, module='stem.act4') + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4') else: if 'tiered' in stem_type: - stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # like 'T' resnets in resnet.py + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py else: stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets strides = (2, 1, 1) stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3') last_idx = len(stem_chs) - 1 for i, (c, s) in enumerate(zip(stem_chs, strides)): - stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) if i != last_idx: - stem[f'act{i+2}'] = act_layer(inplace=True) + stem[f'act{i + 2}'] = act_layer(inplace=True) in_chs = c elif '3x3' in stem_type: # 3x3 stem conv as in RegNet @@ -407,8 +427,7 @@ class NormFreeNet(nn.Module): conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer]) attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - stem_chs = cfg.stem_chs or cfg.channels[0] - stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div) + stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) self.stem, stem_stride, stem_feat = create_stem( in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) @@ -521,184 +540,290 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): @register_model def nfnet_f0(pretrained=False, **kwargs): + """ NFNet-F0 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) @register_model def nfnet_f1(pretrained=False, **kwargs): + """ NFNet-F1 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) @register_model def nfnet_f2(pretrained=False, **kwargs): + """ NFNet-F2 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) @register_model def nfnet_f3(pretrained=False, **kwargs): + """ NFNet-F3 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) @register_model def nfnet_f4(pretrained=False, **kwargs): + """ NFNet-F4 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) @register_model def nfnet_f5(pretrained=False, **kwargs): + """ NFNet-F5 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) @register_model def nfnet_f6(pretrained=False, **kwargs): + """ NFNet-F6 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) @register_model def nfnet_f7(pretrained=False, **kwargs): + """ NFNet-F7 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) @register_model def nfnet_f0s(pretrained=False, **kwargs): + """ NFNet-F0 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) @register_model def nfnet_f1s(pretrained=False, **kwargs): + """ NFNet-F1 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) @register_model def nfnet_f2s(pretrained=False, **kwargs): + """ NFNet-F2 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) @register_model def nfnet_f3s(pretrained=False, **kwargs): + """ NFNet-F3 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) @register_model def nfnet_f4s(pretrained=False, **kwargs): + """ NFNet-F4 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) @register_model def nfnet_f5s(pretrained=False, **kwargs): + """ NFNet-F5 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) @register_model def nfnet_f6s(pretrained=False, **kwargs): + """ NFNet-F6 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) @register_model def nfnet_f7s(pretrained=False, **kwargs): + """ NFNet-F7 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b0(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) - - -@register_model -def nf_regnet_b1(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) - - -@register_model -def nf_regnet_b2(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) - - -@register_model -def nf_regnet_b3(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) +def nfnet_l0a(pretrained=False, **kwargs): + """ NFNet-L0a w/ SiLU + My experimental 'light' model w/ 1280 width stage 3, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0a', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b4(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) +def nfnet_l0b(pretrained=False, **kwargs): + """ NFNet-L0b w/ SiLU + My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0b', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b5(pretrained=False, **kwargs): - return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) +def nfnet_l0c(pretrained=False, **kwargs): + """ NFNet-L0c w/ SiLU + My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('nfnet_l0c', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b0(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B0 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b1(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B1 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b2(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B2 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b3(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B3 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b4(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B4 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b5(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B5 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) @register_model def nf_resnet26(pretrained=False, **kwargs): + """ Normalization-Free ResNet-26 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs) @register_model def nf_resnet50(pretrained=False, **kwargs): + """ Normalization-Free ResNet-50 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs) @register_model def nf_resnet101(pretrained=False, **kwargs): + """ Normalization-Free ResNet-101 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs) @register_model def nf_seresnet26(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet26 + """ return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs) @register_model def nf_seresnet50(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet50 + """ return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs) @register_model def nf_seresnet101(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet101 + """ return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs) @register_model def nf_ecaresnet26(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet26 + """ return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs) @register_model def nf_ecaresnet50(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet50 + """ return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) @register_model def nf_ecaresnet101(pretrained=False, **kwargs): - return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) \ No newline at end of file + """ Normalization-Free ECA-ResNet101 + """ + return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) From 4f49b94311860e5b695d2a919413d1aae4e0eb9c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Feb 2021 23:22:44 -0800 Subject: [PATCH 07/10] Initial AGC impl. Still testing. --- timm/models/__init__.py | 2 +- timm/models/helpers.py | 20 +++++++++++++------- timm/utils/__init__.py | 2 ++ timm/utils/agc.py | 42 +++++++++++++++++++++++++++++++++++++++++ timm/utils/clip_grad.py | 23 ++++++++++++++++++++++ timm/utils/cuda.py | 10 ++++++---- train.py | 11 ++++++++--- 7 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 timm/utils/agc.py create mode 100644 timm/utils/clip_grad.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index dc56848e..8d99d19b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -31,7 +31,7 @@ from .xception import * from .xception_aligned import * from .factory import create_model -from .helpers import load_checkpoint, resume_checkpoint +from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d9b501da..4d9b8a28 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: - _logger.warning("Pretrained model URL does not exist, using random initialization.") + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): + _logger.warning("No pretrained weights exist for this model. Using random initialization.") return url = cfg['url'] @@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): _logger.warning("No pretrained weights exist for this model. Using random initialization.") return @@ -376,3 +374,11 @@ def build_model_with_cfg( model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 0f7c4b05..1c526e8c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -1,4 +1,6 @@ +from .agc import adaptive_clip_grad from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .distributed import distribute_bn, reduce_tensor from .jit import set_jit_legacy diff --git a/timm/utils/agc.py b/timm/utils/agc.py new file mode 100644 index 00000000..f5140172 --- /dev/null +++ b/timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py new file mode 100644 index 00000000..7eb40697 --- /dev/null +++ b/timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index bcd29f58..9e7bddf3 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -11,15 +11,17 @@ except ImportError: amp = None has_apex = False +from .clip_grad import dispatch_clip_grad + class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) optimizer.step() def state_dict(self): @@ -37,12 +39,12 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place - torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) self._scaler.step(optimizer) self._scaler.update() diff --git a/train.py b/train.py index 0333d72f..b787a88c 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -637,11 +637,16 @@ def train_one_epoch( optimizer.zero_grad() if loss_scaler is not None: loss_scaler( - loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) + loss, optimizer, + clip_grad=args.clip_grad, clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, mode=args.clip_mode) optimizer.step() if model_ema is not None: From 01653db104c8d60d2bac643b169c04139f3ae668 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Feb 2021 23:27:16 -0800 Subject: [PATCH 08/10] Missed clip-mode arg for repo train script --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index b787a88c..9abcfed3 100755 --- a/train.py +++ b/train.py @@ -116,7 +116,8 @@ parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') - +parser.add_argument('--clip-mode', type=str, default='norm', + help='Gradient clipping mode. One of ("norm", "value", "agc")') # Learning rate schedule parameters From 9de2ec5e442d2b714b004a7f2c2baf272c9e3854 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Feb 2021 09:12:23 -0800 Subject: [PATCH 09/10] Update README for AGC and bump version to 0.4.4 --- README.md | 8 ++++++++ timm/version.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c4f3a588..8b1f6f28 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,13 @@ ## What's New +### Feb 16, 2021 +* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. + * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` + * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` + * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training. + ### Feb 12, 2021 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs @@ -238,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included. * Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151) * Blur Pooling (https://arxiv.org/abs/1904.11486) * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper? +* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets) ## Results diff --git a/timm/version.py b/timm/version.py index 908c0bb7..9a8e054a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.3' +__version__ = '0.4.4' From 361fd0fc40708c868ff92218f24b01113a617572 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Feb 2021 10:27:41 -0800 Subject: [PATCH 10/10] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8b1f6f28..421bced4 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` - * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training. + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. ### Feb 12, 2021 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs