diff --git a/README.md b/README.md index a5b4b536..421bced4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,20 @@ ## What's New +### Feb 16, 2021 +* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. + * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` + * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` + * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. + +### Feb 12, 2021 +* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs + ### Feb 10, 2021 +* First Normalization-Free model training experiments done, + * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256 + * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256 * More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` @@ -161,6 +174,7 @@ A full version of the list below with source links can be found in the [document * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * NASNet-A - https://arxiv.org/abs/1707.07012 +* NFNet-F - https://arxiv.org/abs/2102.06171 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692 * PNasNet - https://arxiv.org/abs/1712.00559 * RegNet - https://arxiv.org/abs/2003.13678 @@ -231,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included. * Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151) * Blur Pooling (https://arxiv.org/abs/1904.11486) * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper? +* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets) ## Results diff --git a/tests/test_models.py b/tests/test_models.py index 3f1c4cda..407e0fe5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*'] # exclude models that cause specific test failures if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models - EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS + EXCLUDE_FILTERS = [ + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', + 'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS diff --git a/timm/models/__init__.py b/timm/models/__init__.py index dc56848e..8d99d19b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -31,7 +31,7 @@ from .xception import * from .xception_aligned import * from .factory import create_model -from .helpers import load_checkpoint, resume_checkpoint +from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 0908694d..62558913 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -132,10 +132,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - if cfg is None: - cfg = getattr(model, 'default_cfg') - if cfg is None or 'url' not in cfg or not cfg['url']: - _logger.warning("Pretrained model URL does not exist, using random initialization.") + cfg = cfg or getattr(model, 'default_cfg') + if cfg is None or not cfg.get('url', None): + _logger.warning("No pretrained weights exist for this model. Using random initialization.") return url = cfg['url'] @@ -186,8 +185,7 @@ def adapt_input_conv(in_chans, conv_weight): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False, hf_checkpoint=None, hf_revision=None): - if cfg is None: - cfg = getattr(model, 'default_cfg') + cfg = cfg or getattr(model, 'default_cfg') if hf_checkpoint is None: hg_checkpoint = cfg.get('hf_checkpoint') if hf_revision is None: @@ -405,6 +403,7 @@ def build_model_with_cfg( return model + def load_cfg_from_json(json_file: Union[str, os.PathLike]): with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() @@ -417,3 +416,10 @@ def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None): url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir() ) return load_cfg_from_json(cached_filed) + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index c56c5780..b43ee5ef 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -1,10 +1,18 @@ -""" Normalizer Free RegNet / ResNet (pre-activation) Models +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 -NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some. -Details may change, especially once the paper authors release their official models. +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported Hacked together by / copyright Ross Wightman, 2021. """ @@ -28,33 +36,78 @@ def _dcfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', **kwargs } -# FIXME finish -default_cfgs = { - 'nf_regnet_b0': _dcfg(url=''), - 'nf_regnet_b1': _dcfg(url='', input_size=(3, 240, 240), pool_size=(8, 8)), - 'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)), - 'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)), - 'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)), - - 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_resnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'), - - 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'), -} +default_cfgs = dict( + nfnet_f0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nfnet_f0s=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1s=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2s=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3s=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4s=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5s=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6s=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7s=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nfnet_l0a=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_l0b=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_l0c=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + + nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nf_regnet_b1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec + nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)), + nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)), + nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)), + nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)), + + nf_resnet26=_dcfg(url=''), + nf_resnet50=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94), + nf_resnet101=_dcfg(url=''), + + nf_seresnet26=_dcfg(url=''), + nf_seresnet50=_dcfg(url=''), + nf_seresnet101=_dcfg(url=''), + + nf_ecaresnet26=_dcfg(url=''), + nf_ecaresnet50=_dcfg(url=''), + nf_ecaresnet101=_dcfg(url=''), +) @dataclass @@ -65,69 +118,105 @@ class NfCfg: gamma_in_act: bool = False stem_type: str = '3x3' stem_chs: Optional[int] = None - group_size: Optional[int] = 8 - attn_layer: Optional[str] = 'se' - attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8)) + group_size: Optional[int] = None + attn_layer: Optional[str] = None + attn_kwargs: dict = None attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used - width_factor: float = 0.75 - bottle_ratio: float = 2.25 - efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models - num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode) + width_factor: float = 1.0 + bottle_ratio: float = 0.5 + num_features: int = 0 # num out_channels for final conv, no final_conv if 0 ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal - skipinit: bool = False + reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle + extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + skipinit: bool = False # disabled by default, non-trivial performance impact + zero_init_fc: bool = False act_layer: str = 'silu' +def _nfres_cfg( + depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + attn_kwargs = attn_kwargs or {} + cfg = NfCfg( + depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, + group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + +def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): + num_features = 1280 * channels[-1] // 440 + attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, + num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + return cfg + + +def _nfnet_cfg( + depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., + act_layer='gelu', attn_layer='se', attn_kwargs=None): + num_features = int(channels[-1] * feat_mult) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, + bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, + attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + model_cfgs = dict( - # EffNet influenced RegNet defs - nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416), - nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536), - nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792), - nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048), + # NFNet-F models w/ GeLU + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), + nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), + nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), + nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), + nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), + nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), + nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), + nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + + # Experimental 'light' versions of nfnet-f that are little leaner + nfnet_l0a=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1280, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + nfnet_l0b=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + nfnet_l0c=_nfnet_cfg( + depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. + nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + # FIXME add B6-B8 # ResNet (preact, D style deep stem/avg down) defs - nf_resnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None,), - nf_resnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - nf_resnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - - - nf_seresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - - - nf_ecaresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()), ) @@ -166,20 +255,20 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -class NormalizationFreeBlock(nn.Module): - """Normalization-free pre-activation block. +class NormFreeBlock(nn.Module): + """Normalization-Free pre-activation block. """ def __init__( self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, - alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None, - attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False): + alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, + skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs - # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet - mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div) - groups = 1 if group_size is None else mid_chs // group_size + # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div) + groups = 1 if not group_size else mid_chs // group_size if group_size and group_size % ch_div == 0: mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error self.alpha = alpha @@ -196,12 +285,22 @@ class NormalizationFreeBlock(nn.Module): self.conv1 = conv_layer(in_chs, mid_chs, 1) self.act2 = act_layer(inplace=True) self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - if attn_layer is not None: - self.attn = attn_layer(mid_chs) + if extra_conv: + self.act2b = act_layer(inplace=True) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + else: + self.act2b = None + self.conv2b = None + if reg and attn_layer is not None: + self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 else: self.attn = None self.act3 = act_layer() self.conv3 = conv_layer(mid_chs, out_chs, 1) + if not reg and attn_layer is not None: + self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + else: + self.attn_last = None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None @@ -216,28 +315,48 @@ class NormalizationFreeBlock(nn.Module): # residual branch out = self.conv1(out) out = self.conv2(self.act2(out)) + if self.conv2b is not None: + out = self.conv2b(self.act2b(out)) if self.attn is not None: out = self.attn_gain * self.attn(out) out = self.conv3(self.act3(out)) + if self.attn_last is not None: + out = self.attn_gain * self.attn_last(out) out = self.drop_path(out) - if self.skipinit_gain is None: - out = out * self.alpha + shortcut - else: + + if self.skipinit_gain is not None: # this really slows things down for some reason, TBD - out = out * self.alpha * self.skipinit_gain + shortcut + out = out * self.skipinit_gain + out = out * self.alpha + shortcut return out -def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): +def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_stride = 2 + stem_feature = dict(num_chs=out_chs, reduction=2, module='') stem = OrderedDict() - assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') if 'deep' in stem_type: - # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here - mid_chs = out_chs // 2 - stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) - stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) - stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if 'quad' in stem_type: + # 4 deep conv stack as in NFNet-F models + assert not 'pool' in stem_type + stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs) + strides = (2, 1, 1, 2) + stem_stride = 4 + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4') + else: + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets + strides = (2, 1, 1) + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3') + last_idx = len(stem_chs) - 1 + for i, (c, s) in enumerate(zip(stem_chs, strides)): + stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + if i != last_idx: + stem[f'act{i + 2}'] = act_layer(inplace=True) + in_chs = c elif '3x3' in stem_type: # 3x3 stem conv as in RegNet stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) @@ -249,21 +368,37 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) stem_stride = 4 - return nn.Sequential(stem), stem_stride + return nn.Sequential(stem), stem_stride, stem_feature +# from https://github.com/deepmind/deepmind-research/tree/master/nfnets _nonlin_gamma = dict( - silu=1./.5595, - relu=(0.5 * (1. - 1. / math.pi)) ** -0.5, - identity=1.0 + identity=1.0, + celu=1.270926833152771, + elu=1.2716004848480225, + gelu=1.7015043497085571, + leaky_relu=1.70590341091156, + log_sigmoid=1.9193484783172607, + log_softmax=1.0002083778381348, + relu=1.7139588594436646, + relu6=1.7131484746932983, + selu=1.0008515119552612, + sigmoid=4.803835391998291, + silu=1.7881293296813965, + softsign=2.338853120803833, + softplus=1.9203323125839233, + tanh=1.5939117670059204, ) -class NormalizerFreeNet(nn.Module): - """ Normalizer-free ResNets and RegNets +class NormFreeNet(nn.Module): + """ Normalization-Free Network - As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + As described in : + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 + and + `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and the (preact) ResNet models described earlier in the paper. @@ -274,7 +409,7 @@ class NormalizerFreeNet(nn.Module): * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl. * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but - apply it in each activation. This is slightly slower, and yields slightly different results. + apply it in each activation. This is slightly slower, numerically different, but matches official impl. * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does. Approx 8-10% throughput loss. """ @@ -292,12 +427,12 @@ class NormalizerFreeNet(nn.Module): conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer]) attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - stem_chs = cfg.stem_chs or cfg.channels[0] - stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div) - self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer) + stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) + self.stem, stem_stride, stem_feat = create_stem( + in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) - self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + self.feature_info = [stem_feat] if stem_stride == 4 else [] + drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] prev_chs = stem_chs net_stride = stem_stride dilation = 1 @@ -305,8 +440,8 @@ class NormalizerFreeNet(nn.Module): stages = [] for stage_idx, stage_depth in enumerate(cfg.depths): stride = 1 if stage_idx == 0 and stem_stride > 2 else 2 - self.feature_info += [dict( - num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')] + if stride == 2: + self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')] if net_stride >= output_stride and stride > 1: dilation *= stride stride = 1 @@ -317,23 +452,24 @@ class NormalizerFreeNet(nn.Module): for block_idx in range(cfg.depths[stage_idx]): first_block = block_idx == 0 and stage_idx == 0 out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div) - blocks += [NormalizationFreeBlock( + blocks += [NormFreeBlock( in_chs=prev_chs, out_chs=out_chs, alpha=cfg.alpha, - beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block + beta=1. / expected_var ** 0.5, stride=stride if block_idx == 0 else 1, dilation=dilation, first_dilation=first_dilation, group_size=cfg.group_size, - bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio, - efficient=cfg.efficient, + bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio, ch_div=cfg.ch_div, + reg=cfg.reg, + extra_conv=cfg.extra_conv, + skipinit=cfg.skipinit, attn_layer=attn_layer, attn_gain=cfg.attn_gain, act_layer=act_layer, conv_layer=conv_layer, - drop_path_rate=dpr[stage_idx][block_idx], - skipinit=cfg.skipinit, + drop_path_rate=drop_path_rates[stage_idx][block_idx], )] if block_idx == 0: expected_var = 1. # expected var is reset after first block of each stage @@ -343,27 +479,27 @@ class NormalizerFreeNet(nn.Module): stages += [nn.Sequential(*blocks)] self.stages = nn.Sequential(*stages) - if cfg.efficient and cfg.num_features: + if cfg.num_features: # The paper NFRegNet models have an EfficientNet-like final head convolution. self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) self.final_conv = conv_layer(prev_chs, self.num_features, 1) else: self.num_features = prev_chs self.final_conv = nn.Identity() - # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv - self.final_act = act_layer() + self.final_act = act_layer(inplace=cfg.num_features > 0) self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')] self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) for n, m in self.named_modules(): if 'fc' in n and isinstance(m, nn.Linear): - nn.init.zeros_(m.weight) + if cfg.zero_init_fc: + nn.init.zeros_(m.weight) + else: + nn.init.normal_(m.weight, 0., .01) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): - # as per discussion with paper authors, original in haiku is - # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear') if m.bias is not None: nn.init.zeros_(m.bias) @@ -391,86 +527,303 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): model_cfg = model_cfgs[variant] feature_cfg = dict(flatten_sequential=True) feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks - if 'pool' in model_cfg.stem_type: - feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet + if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type: + feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems return build_model_with_cfg( - NormalizerFreeNet, variant, pretrained, + NormFreeNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfg, feature_cfg=feature_cfg, **kwargs) +@register_model +def nfnet_f0(pretrained=False, **kwargs): + """ NFNet-F0 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1(pretrained=False, **kwargs): + """ NFNet-F1 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2(pretrained=False, **kwargs): + """ NFNet-F2 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3(pretrained=False, **kwargs): + """ NFNet-F3 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f4(pretrained=False, **kwargs): + """ NFNet-F4 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f5(pretrained=False, **kwargs): + """ NFNet-F5 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6(pretrained=False, **kwargs): + """ NFNet-F6 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f7(pretrained=False, **kwargs): + """ NFNet-F7 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0s(pretrained=False, **kwargs): + """ NFNet-F0 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1s(pretrained=False, **kwargs): + """ NFNet-F1 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2s(pretrained=False, **kwargs): + """ NFNet-F2 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3s(pretrained=False, **kwargs): + """ NFNet-F3 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f4s(pretrained=False, **kwargs): + """ NFNet-F4 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f5s(pretrained=False, **kwargs): + """ NFNet-F5 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6s(pretrained=False, **kwargs): + """ NFNet-F6 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f7s(pretrained=False, **kwargs): + """ NFNet-F7 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_l0a(pretrained=False, **kwargs): + """ NFNet-L0a w/ SiLU + My experimental 'light' model w/ 1280 width stage 3, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0a', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_l0b(pretrained=False, **kwargs): + """ NFNet-L0b w/ SiLU + My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0b', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_l0c(pretrained=False, **kwargs): + """ NFNet-L0c w/ SiLU + My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('nfnet_l0c', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B0 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b1(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B1 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b2(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B2 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b3(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B3 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b4(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B4 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) @register_model def nf_regnet_b5(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B5 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) @register_model def nf_resnet26(pretrained=False, **kwargs): + """ Normalization-Free ResNet-26 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs) @register_model def nf_resnet50(pretrained=False, **kwargs): + """ Normalization-Free ResNet-50 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs) @register_model def nf_resnet101(pretrained=False, **kwargs): + """ Normalization-Free ResNet-101 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs) @register_model def nf_seresnet26(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet26 + """ return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs) @register_model def nf_seresnet50(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet50 + """ return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs) @register_model def nf_seresnet101(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet101 + """ return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs) @register_model def nf_ecaresnet26(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet26 + """ return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs) @register_model def nf_ecaresnet50(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet50 + """ return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) + @register_model def nf_ecaresnet101(pretrained=False, **kwargs): - return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) \ No newline at end of file + """ Normalization-Free ECA-ResNet101 + """ + return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 0f7c4b05..1c526e8c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -1,4 +1,6 @@ +from .agc import adaptive_clip_grad from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .distributed import distribute_bn, reduce_tensor from .jit import set_jit_legacy diff --git a/timm/utils/agc.py b/timm/utils/agc.py new file mode 100644 index 00000000..f5140172 --- /dev/null +++ b/timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py new file mode 100644 index 00000000..7eb40697 --- /dev/null +++ b/timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index bcd29f58..9e7bddf3 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -11,15 +11,17 @@ except ImportError: amp = None has_apex = False +from .clip_grad import dispatch_clip_grad + class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) optimizer.step() def state_dict(self): @@ -37,12 +39,12 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place - torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) self._scaler.step(optimizer) self._scaler.update() diff --git a/timm/version.py b/timm/version.py index 908c0bb7..9a8e054a 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.3' +__version__ = '0.4.4' diff --git a/train.py b/train.py index 0333d72f..9abcfed3 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -116,7 +116,8 @@ parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') - +parser.add_argument('--clip-mode', type=str, default='norm', + help='Gradient clipping mode. One of ("norm", "value", "agc")') # Learning rate schedule parameters @@ -637,11 +638,16 @@ def train_one_epoch( optimizer.zero_grad() if loss_scaler is not None: loss_scaler( - loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) + loss, optimizer, + clip_grad=args.clip_grad, clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, mode=args.clip_mode) optimizer.step() if model_ema is not None: