Merge branch 'master' into master

pull/440/head
Sylvain Gugger 4 years ago committed by GitHub
commit 482ab548dc

@ -2,7 +2,20 @@
## What's New
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
* PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
* PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
* AGC performance is definitely sensitive to the clipping factor. More experimentation is needed to determine good values for smaller batch sizes and optimizers besides those in the paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. A minimal usage sketch follows below.
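A rough sketch of wiring the new clipping utilities into a custom training loop (helper names as added in this release; illustrative only, not a replacement for `train.py`):

```python
import torch
import timm
from timm.models import model_parameters
from timm.utils import dispatch_clip_grad

model = timm.create_model('nf_resnet50', pretrained=False)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)

loss = model(torch.randn(2, 3, 256, 256)).mean()  # dummy forward/loss for illustration
loss.backward()
# AGC applied to everything but the classifier weights, mirroring train.py w/ `--clip-mode agc`
dispatch_clip_grad(model_parameters(model, exclude_head=True), value=0.01, mode='agc')
optimizer.step()
```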
### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
### Feb 10, 2021
* First Normalization-Free model training experiments done,
* nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
* nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256
* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
* GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
* RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
@ -161,6 +174,7 @@ A full version of the list below with source links can be found in the [document
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* NASNet-A - https://arxiv.org/abs/1707.07012
* NFNet-F - https://arxiv.org/abs/2102.06171
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
* RegNet - https://arxiv.org/abs/2003.13678
@ -231,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included.
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
## Results

@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*']
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
else:
EXCLUDE_FILTERS = NON_STD_FILTERS

@ -31,7 +31,7 @@ from .xception import *
from .xception_aligned import *
from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit

@ -132,10 +132,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file. Default: False
"""
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
url = cfg['url']
@ -186,8 +185,7 @@ def adapt_input_conv(in_chans, conv_weight):
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
progress=False, hf_checkpoint=None, hf_revision=None):
cfg = cfg or getattr(model, 'default_cfg')
if hf_checkpoint is None:
hf_checkpoint = cfg.get('hf_checkpoint')
if hf_revision is None:
@ -405,6 +403,7 @@ def build_model_with_cfg(
return model
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
@ -417,3 +416,10 @@ def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
return load_cfg_from_json(cached_filed)
def model_parameters(model, exclude_head=False):
if exclude_head:
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
else:
return model.parameters()
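# Illustrative usage sketch (not part of this file): model_parameters() relies on the classifier
# weight and bias being the last two parameters in iteration order, as the FIXME above notes.
import timm

model = timm.create_model('nf_regnet_b1', pretrained=False)
body_params = model_parameters(model, exclude_head=True)
assert len(list(model.parameters())) - len(body_params) == 2  # head.fc.weight and head.fc.bias dropped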

@ -1,10 +1,18 @@
""" Normalizer Free RegNet / ResNet (pre-activation) Models """ Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692 - https://arxiv.org/abs/2101.08692
NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some. Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
Details may change, especially once the paper authors release their official models. - https://arxiv.org/abs/2102.06171
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Status:
* These models are a work in progress, experiments ongoing.
* Pretrained weights for two models so far, more to come.
* Model details updated to closer match official JAX code now that it's released
* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
Hacked together by / copyright Ross Wightman, 2021.
"""
@ -28,33 +36,78 @@ def _dcfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
nfnet_f0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_f1=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
nfnet_f2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
nfnet_f3=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
nfnet_f4=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
nfnet_f5=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
nfnet_f6=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
nfnet_f7=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
nfnet_f0s=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_f1s=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
nfnet_f2s=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
nfnet_f3s=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
nfnet_f4s=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
nfnet_f5s=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
nfnet_f6s=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
nfnet_f7s=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
nfnet_l0a=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_l0b=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_l0c=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
nf_resnet26=_dcfg(url=''),
nf_resnet50=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
nf_resnet101=_dcfg(url=''),
nf_seresnet26=_dcfg(url=''),
nf_seresnet50=_dcfg(url=''),
nf_seresnet101=_dcfg(url=''),
nf_ecaresnet26=_dcfg(url=''),
nf_ecaresnet50=_dcfg(url=''),
nf_ecaresnet101=_dcfg(url=''),
)
@dataclass
@ -65,69 +118,105 @@ class NfCfg:
gamma_in_act: bool = False
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = None
attn_layer: Optional[str] = None
attn_kwargs: dict = None
attn_gain: float = 2.0  # NF correction gain to apply if attn layer is used
width_factor: float = 1.0
bottle_ratio: float = 0.5
num_features: int = 0  # num out_channels for final conv, no final_conv if 0
ch_div: int = 8  # round channels % 8 == 0 to keep tensor-core use optimal
reg: bool = False  # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
skipinit: bool = False # disabled by default, non-trivial performance impact
zero_init_fc: bool = False
act_layer: str = 'silu'
def _nfres_cfg(
depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
attn_kwargs = attn_kwargs or {}
cfg = NfCfg(
depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
return cfg
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
return cfg
def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
attn_layer=attn_layer, attn_kwargs=attn_kwargs)
return cfg
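# Rough illustration (not part of the diff): with the helper defaults above, the 'nfnet_f0' entry created
# below expands to approximately these NfCfg field values.
_example_nfnet_f0 = _nfnet_cfg(depths=(1, 2, 6, 3))
# -> depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), stem_type='deep_quad', stem_chs=128,
#    group_size=128, bottle_ratio=0.5, extra_conv=True, num_features=3072 (1536 * 2.),
#    act_layer='gelu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.5, divisor=8)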
model_cfgs = dict(
# NFNet-F models w/ GeLU
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
# NFNet-F models w/ SiLU (much faster in PyTorch)
nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
# Experimental 'light' versions of nfnet-f that are a little leaner
nfnet_l0a=_nfnet_cfg(
depths=(1, 2, 6, 3), channels=(256, 512, 1280, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
nfnet_l0b=_nfnet_cfg(
depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
nfnet_l0c=_nfnet_cfg(
depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
# EffNet influenced RegNet defs.
# NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
# FIXME add B6-B8
# ResNet (preact, D style deep stem/avg down) defs
nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
)
@ -166,20 +255,20 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
# RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
groups = 1 if not group_size else mid_chs // group_size
if group_size and group_size % ch_div == 0:
mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
self.alpha = alpha
@ -196,12 +285,22 @@ class NormalizationFreeBlock(nn.Module):
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.act2 = act_layer(inplace=True)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
if extra_conv:
self.act2b = act_layer(inplace=True)
self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
else:
self.act2b = None
self.conv2b = None
if reg and attn_layer is not None:
self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3
else:
self.attn = None
self.act3 = act_layer()
self.conv3 = conv_layer(mid_chs, out_chs, 1)
if not reg and attn_layer is not None:
self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3
else:
self.attn_last = None
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
@ -216,28 +315,48 @@ class NormalizationFreeBlock(nn.Module):
# residual branch
out = self.conv1(out)
out = self.conv2(self.act2(out))
if self.conv2b is not None:
out = self.conv2b(self.act2b(out))
if self.attn is not None:
out = self.attn_gain * self.attn(out)
out = self.conv3(self.act3(out))
if self.attn_last is not None:
out = self.attn_gain * self.attn_last(out)
out = self.drop_path(out)
if self.skipinit_gain is not None:
# this really slows things down for some reason, TBD
out = out * self.skipinit_gain
out = out * self.alpha + shortcut
return out
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
stem_stride = 2
stem_feature = dict(num_chs=out_chs, reduction=2, module='')
stem = OrderedDict()
assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
if 'deep' in stem_type:
if 'quad' in stem_type:
# 4 deep conv stack as in NFNet-F models
assert not 'pool' in stem_type
stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
strides = (2, 1, 1, 2)
stem_stride = 4
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4')
else:
if 'tiered' in stem_type:
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
else:
stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
strides = (2, 1, 1)
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
last_idx = len(stem_chs) - 1
for i, (c, s) in enumerate(zip(stem_chs, strides)):
stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
if i != last_idx:
stem[f'act{i + 2}'] = act_layer(inplace=True)
in_chs = c
elif '3x3' in stem_type:
# 3x3 stem conv as in RegNet
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
@ -249,21 +368,37 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
stem_stride = 4
return nn.Sequential(stem), stem_stride, stem_feature
# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
_nonlin_gamma = dict(
identity=1.0,
celu=1.270926833152771,
elu=1.2716004848480225,
gelu=1.7015043497085571,
leaky_relu=1.70590341091156,
log_sigmoid=1.9193484783172607,
log_softmax=1.0002083778381348,
relu=1.7139588594436646,
relu6=1.7131484746932983,
selu=1.0008515119552612,
sigmoid=4.803835391998291,
silu=1.7881293296813965,
softsign=2.338853120803833,
softplus=1.9203323125839233,
tanh=1.5939117670059204,
)
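# Quick empirical sanity check of the constants above (illustrative sketch, not library code): each gamma
# is roughly 1 / std(act(x)) over x ~ N(0, 1), i.e. the scale that keeps the activation variance-preserving.
import torch
import torch.nn.functional as F

_x = torch.randn(1024, 1024)
for _name, _act in [('relu', F.relu), ('gelu', F.gelu), ('silu', F.silu), ('tanh', torch.tanh)]:
    print(_name, (1.0 / _act(_x).std()).item())  # expect roughly 1.714, 1.702, 1.788, 1.594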
class NormFreeNet(nn.Module):
""" Normalization-Free Network
As described in:
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
and
`High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
the (preact) ResNet models described earlier in the paper.
@ -274,7 +409,7 @@ class NormalizerFreeNet(nn.Module):
* activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
* a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
apply it in each activation. This is slightly slower, numerically different, but matches official impl.
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
for what it is/does. Approx 8-10% throughput loss.
"""
@ -292,12 +427,12 @@ class NormalizerFreeNet(nn.Module):
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
self.feature_info = [stem_feat] if stem_stride == 4 else []
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
prev_chs = stem_chs
net_stride = stem_stride
dilation = 1
@ -305,8 +440,8 @@ class NormalizerFreeNet(nn.Module):
stages = []
for stage_idx, stage_depth in enumerate(cfg.depths):
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
if stride == 2:
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
@ -317,23 +452,24 @@ class NormalizerFreeNet(nn.Module):
for block_idx in range(cfg.depths[stage_idx]):
first_block = block_idx == 0 and stage_idx == 0
out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
blocks += [NormFreeBlock(
in_chs=prev_chs, out_chs=out_chs,
alpha=cfg.alpha,
beta=1. / expected_var ** 0.5,
stride=stride if block_idx == 0 else 1,
dilation=dilation,
first_dilation=first_dilation,
group_size=cfg.group_size,
bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
ch_div=cfg.ch_div,
reg=cfg.reg,
extra_conv=cfg.extra_conv,
skipinit=cfg.skipinit,
attn_layer=attn_layer,
attn_gain=cfg.attn_gain,
act_layer=act_layer,
conv_layer=conv_layer,
drop_path_rate=drop_path_rates[stage_idx][block_idx],
)]
if block_idx == 0:
expected_var = 1.  # expected var is reset after first block of each stage
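# Side-note sketch (standalone, not model code): the variance bookkeeping above, assuming the full source
# updates expected_var += alpha ** 2 after each block (alpha defaults to 0.2 in the NF papers).
_alpha, _expected_var = 0.2, 1.0
for _i in range(4):
    _beta = 1. / _expected_var ** 0.5
    print(f'block {_i}: beta={_beta:.3f}, expected_var={_expected_var:.2f}')
    if _i == 0:
        _expected_var = 1.  # reset after the first (transition) block of a stage
    _expected_var += _alpha ** 2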
@ -343,27 +479,27 @@ class NormalizerFreeNet(nn.Module):
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)
if cfg.num_features:
# The paper NFRegNet models have an EfficientNet-like final head convolution.
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.final_act = act_layer(inplace=cfg.num_features > 0)
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
if 'fc' in n and isinstance(m, nn.Linear):
if cfg.zero_init_fc:
nn.init.zeros_(m.weight)
else:
nn.init.normal_(m.weight, 0., .01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
# as per discussion with paper authors, original in haiku is
# hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
if m.bias is not None:
nn.init.zeros_(m.bias)
@ -391,86 +527,303 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
model_cfg = model_cfgs[variant]
feature_cfg = dict(flatten_sequential=True)
feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2 feat for stride 4, 1 layer maxpool stems
return build_model_with_cfg(
NormFreeNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfg,
feature_cfg=feature_cfg,
**kwargs)
@register_model
def nfnet_f0(pretrained=False, **kwargs):
""" NFNet-F0
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f1(pretrained=False, **kwargs):
""" NFNet-F1
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f2(pretrained=False, **kwargs):
""" NFNet-F2
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f3(pretrained=False, **kwargs):
""" NFNet-F3
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f4(pretrained=False, **kwargs):
""" NFNet-F4
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f5(pretrained=False, **kwargs):
""" NFNet-F5
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f6(pretrained=False, **kwargs):
""" NFNet-F6
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f7(pretrained=False, **kwargs):
""" NFNet-F7
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f0s(pretrained=False, **kwargs):
""" NFNet-F0 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f1s(pretrained=False, **kwargs):
""" NFNet-F1 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f2s(pretrained=False, **kwargs):
""" NFNet-F2 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f3s(pretrained=False, **kwargs):
""" NFNet-F3 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f4s(pretrained=False, **kwargs):
""" NFNet-F4 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f5s(pretrained=False, **kwargs):
""" NFNet-F5 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f6s(pretrained=False, **kwargs):
""" NFNet-F6 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f7s(pretrained=False, **kwargs):
""" NFNet-F7 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_l0a(pretrained=False, **kwargs):
""" NFNet-L0a w/ SiLU
My experimental 'light' model w/ 1280 width stage 3, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
"""
return _create_normfreenet('nfnet_l0a', pretrained=pretrained, **kwargs)
@register_model
def nfnet_l0b(pretrained=False, **kwargs):
""" NFNet-L0b w/ SiLU
My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
"""
return _create_normfreenet('nfnet_l0b', pretrained=pretrained, **kwargs)
@register_model
def nfnet_l0c(pretrained=False, **kwargs):
""" NFNet-L0c w/ SiLU
My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('nfnet_l0c', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B0
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b1(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B1
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b2(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B2
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b3(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B3
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b4(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B4
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b5(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B5
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet26(pretrained=False, **kwargs):
""" Normalization-Free ResNet-26
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet50(pretrained=False, **kwargs):
""" Normalization-Free ResNet-50
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet101(pretrained=False, **kwargs):
""" Normalization-Free ResNet-101
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
"""
return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet26(pretrained=False, **kwargs):
""" Normalization-Free SE-ResNet26
"""
return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet50(pretrained=False, **kwargs):
""" Normalization-Free SE-ResNet50
"""
return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet101(pretrained=False, **kwargs):
""" Normalization-Free SE-ResNet101
"""
return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet26(pretrained=False, **kwargs):
""" Normalization-Free ECA-ResNet26
"""
return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet50(pretrained=False, **kwargs):
""" Normalization-Free ECA-ResNet50
"""
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet101(pretrained=False, **kwargs):
""" Normalization-Free ECA-ResNet101
"""
return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
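# Smoke-test sketch (illustrative, not part of nfnet.py): the registrations above make the new models
# available via the regular factory. No pretrained weights are wired up for the NFNet-F configs here.
import torch
import timm

_m = timm.create_model('nfnet_f0s', pretrained=False).eval()
with torch.no_grad():
    _out = _m(torch.randn(1, 3, 192, 192))  # (3, 192, 192) is the default train input_size for nfnet_f0*
print(_out.shape)  # torch.Size([1, 1000])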

@ -1,4 +1,6 @@
from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy

@ -0,0 +1,42 @@
""" Adaptive Gradient Clipping
An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:2102.06171},
year={2021}
}
Code references:
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
def unitwise_norm(x, norm_type=2.0):
if x.ndim <= 1:
return x.norm(norm_type)
else:
# works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
# might need special cases for other weights (possibly MHA) where this may not be true
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
p_data = p.detach()
g_data = p.grad.detach()
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
p.grad.detach().copy_(new_grads)
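# Standalone sketch of the rule above (illustrative, not part of this module): gradients are rescaled
# wherever their unit-wise norm exceeds clip_factor * max(unit-wise param norm, eps).
if __name__ == '__main__':
    _w = torch.nn.Parameter(torch.randn(8, 4, 3, 3))  # conv-like weight, 8 output units
    _w.grad = 100. * torch.randn_like(_w)              # deliberately oversized gradient
    adaptive_clip_grad([_w], clip_factor=0.01)
    print(_w.grad.norm(dim=(1, 2, 3)))                 # per-unit grad norms now <= ~0.01 * per-unit weight norms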

@ -0,0 +1,23 @@
import torch
from timm.utils.agc import adaptive_clip_grad
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
""" Dispatch to gradient clipping method
Args:
parameters (Iterable): model parameters to clip
value (float): clipping value/factor/norm, mode dependent
mode (str): clipping mode, one of 'norm', 'value', 'agc'
norm_type (float): p-norm, default 2.0
"""
if mode == 'norm':
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
elif mode == 'value':
torch.nn.utils.clip_grad_value_(parameters, value)
elif mode == 'agc':
adaptive_clip_grad(parameters, value, norm_type=norm_type)
else:
assert False, f"Unknown clip mode ({mode})."
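# Minimal sketch exercising each mode (illustrative only; same entry points train.py calls):
if __name__ == '__main__':
    _p = torch.nn.Parameter(torch.randn(10, 10))
    _p.grad = torch.randn(10, 10)
    dispatch_clip_grad([_p], value=1.0, mode='norm')    # global L2 norm clipping (previous default behaviour)
    dispatch_clip_grad([_p], value=10., mode='value')   # per-element value clipping
    dispatch_clip_grad([_p], value=0.01, mode='agc')    # adaptive gradient clipping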

@ -11,15 +11,17 @@ except ImportError:
amp = None
has_apex = False
from .clip_grad import dispatch_clip_grad
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if clip_grad is not None:
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
optimizer.step()
def state_dict(self):
@ -37,12 +39,12 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
self._scaler.scale(loss).backward(create_graph=create_graph)
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()

@ -1 +1 @@
__version__ = '0.4.4'

@ -29,7 +29,7 @@ import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
@ -116,7 +116,8 @@ parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
# Learning rate schedule parameters
@ -637,11 +638,16 @@ def train_one_epoch(
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
if model_ema is not None:
