Merge pull request #427 from rwightman/nfnet

Add NFNet-F models and tweak existing NF models.
Ross Wightman, committed via GitHub
commit 4df513c68f

README.md
@@ -2,6 +2,9 @@
 ## What's New
+### Feb 12, 2021
+* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
+
 ### Feb 10, 2021
 * First Normalization-Free model training experiments done,
   * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
@@ -164,6 +167,7 @@ A full version of the list below with source links can be found in the [document
 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
 * NASNet-A - https://arxiv.org/abs/1707.07012
+* NFNet-F - https://arxiv.org/abs/2102.06171
 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
 * PNasNet - https://arxiv.org/abs/1712.00559
 * RegNet - https://arxiv.org/abs/2003.13678
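
For anyone wanting to try the new defs, a minimal smoke-test sketch (assumes this branch of `timm` is installed; NFNet-F pretrained weights are not released yet, so `pretrained=False`):

```python
import torch
import timm

model = timm.create_model('nfnet_f0', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 192, 192))  # nfnet_f0 default train res is 192x192
print(out.shape)  # torch.Size([1, 1000])
```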

tests/test_models.py
@@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*']

 # exclude models that cause specific test failures
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
+    EXCLUDE_FILTERS = [
+        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
+        'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
 else:
     EXCLUDE_FILTERS = NON_STD_FILTERS
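
The exclude filters are wildcard patterns; a small sketch of fnmatch-style matching of this kind (the `filter_models` helper is illustrative, not from the test suite):

```python
import fnmatch

def filter_models(model_names, exclude_filters):
    # keep only models not matched by any exclude pattern
    return [m for m in model_names if not any(fnmatch.fnmatch(m, f) for f in exclude_filters)]

names = ['nfnet_f0', 'nfnet_f4', 'nfnet_f7s', 'nf_resnet50']
print(filter_models(names, ['nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*']))
# ['nfnet_f0', 'nf_resnet50']
```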

timm/models/nfnet.py
@@ -1,10 +1,18 @@
-""" Normalizer Free RegNet / ResNet (pre-activation) Models
+""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models

 Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
     - https://arxiv.org/abs/2101.08692

-NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some.
-Details may change, especially once the paper authors release their official models.
+Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
+    - https://arxiv.org/abs/2102.06171
+
+Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
+
+Status:
+* These models are a work in progress, experiments ongoing.
+* Pretrained weights for two models so far, more to come.
+* Model details updated to closer match official JAX code now that it's released
+* NF-ResNet, NF-RegNet-B, and NFNet-F models supported

 Hacked together by / copyright Ross Wightman, 2021.
 """
@@ -28,37 +36,71 @@ def _dcfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'crop_pct': 0.9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv', 'classifier': 'head.fc',
         **kwargs
     }


-default_cfgs = {
-    'nf_regnet_b0': _dcfg(url=''),
-    'nf_regnet_b1': _dcfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
-        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9),
-    'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)),
-    'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)),
-    'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)),
-    'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
-    'nf_resnet50': _dcfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
-        first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94),
-    'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
-    'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
-    'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
-    'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
-    'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
-    'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
-    'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
-}
+default_cfgs = dict(
+    nfnet_f0=_dcfg(
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+    nfnet_f1=_dcfg(
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
+    nfnet_f2=_dcfg(
+        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
+    nfnet_f3=_dcfg(
+        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
+    nfnet_f4=_dcfg(
+        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
+    nfnet_f5=_dcfg(
+        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
+    nfnet_f6=_dcfg(
+        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
+    nfnet_f7=_dcfg(
+        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
+
+    nfnet_f0s=_dcfg(
+        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
+    nfnet_f1s=_dcfg(
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
+    nfnet_f2s=_dcfg(
+        url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
+    nfnet_f3s=_dcfg(
+        url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
+    nfnet_f4s=_dcfg(
+        url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
+    nfnet_f5s=_dcfg(
+        url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
+    nfnet_f6s=_dcfg(
+        url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
+    nfnet_f7s=_dcfg(
+        url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
+
+    nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+    nf_regnet_b1=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)),  # NOT to paper spec
+    nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
+    nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
+    nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
+    nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
+
+    nf_resnet26=_dcfg(url=''),
+    nf_resnet50=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
+    nf_resnet101=_dcfg(url=''),
+
+    nf_seresnet26=_dcfg(url=''),
+    nf_seresnet50=_dcfg(url=''),
+    nf_seresnet101=_dcfg(url=''),
+
+    nf_ecaresnet26=_dcfg(url=''),
+    nf_ecaresnet50=_dcfg(url=''),
+    nf_ecaresnet101=_dcfg(url=''),
+)


 @dataclass
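
The new `test_input_size` entries encode the train/test resolution split from the NFNet paper; a sketch of how a consumer might pick the eval resolution from such a cfg (plain-dict stand-in, not the actual timm resolve logic):

```python
# 'test_input_size' is the larger eval resolution; fall back to 'input_size'
# when it's absent.
cfg = dict(input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=0.9)
eval_size = cfg.get('test_input_size', cfg['input_size'])
print(cfg['input_size'], '->', eval_size)  # (3, 192, 192) -> (3, 256, 256)
```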
@@ -69,69 +111,92 @@ class NfCfg:
     gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
-    group_size: Optional[int] = 8
-    attn_layer: Optional[str] = 'se'
-    attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8))
+    group_size: Optional[int] = None
+    attn_layer: Optional[str] = None
+    attn_kwargs: dict = None
     attn_gain: float = 2.0  # NF correction gain to apply if attn layer is used
-    width_factor: float = 0.75
-    bottle_ratio: float = 2.25
-    efficient: bool = True  # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models
-    num_features: int = 1280  # num out_channels for final conv (when enabled in efficient mode)
+    width_factor: float = 1.0
+    bottle_ratio: float = 0.5
+    num_features: int = 0  # num out_channels for final conv, no final_conv if 0
     ch_div: int = 8  # round channels % 8 == 0 to keep tensor-core use optimal
-    skipinit: bool = False
+    reg: bool = False  # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
+    extra_conv: bool = False  # extra 3x3 bottleneck convolution for NFNet models
+    skipinit: bool = False  # disabled by default, non-trivial performance impact
+    zero_init_fc: bool = False
     act_layer: str = 'silu'


+def _nfres_cfg(
+        depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
+    attn_kwargs = attn_kwargs or {}
+    cfg = NfCfg(
+        depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
+        group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+    return cfg
+
+
+def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
+    num_features = 1280 * channels[-1] // 440
+    attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
+    cfg = NfCfg(
+        depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
+        num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
+    return cfg
+
+
+def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
+    channels = (256, 512, 1536, 1536)
+    num_features = channels[-1] * 2
+    attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8)
+    cfg = NfCfg(
+        depths=depths, channels=channels, stem_type='deep_quad', group_size=128, bottle_ratio=0.5, extra_conv=True,
+        num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+    return cfg
+
+
 model_cfgs = dict(
-    # EffNet influenced RegNet defs
-    nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280),
-    nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280),
-    nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416),
-    nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536),
-    nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792),
-    nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
+    # NFNet-F models w/ GeLU
+    nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
+    nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
+    nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
+    nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
+    nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
+    nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
+    nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
+    nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
+
+    # NFNet-F models w/ SiLU (much faster in PyTorch)
+    nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
+    nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
+    nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
+    nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
+    nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
+    nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
+    nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
+    nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
+
+    # EffNet influenced RegNet defs.
+    # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
+    nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
+    nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
+    nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
+    nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
+    nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
+    nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
+    # FIXME add B6-B8

     # ResNet (preact, D style deep stem/avg down) defs
-    nf_resnet26=NfCfg(
-        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer=None,),
-    nf_resnet50=NfCfg(
-        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer=None),
-    nf_resnet101=NfCfg(
-        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer=None),
+    nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
+    nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
+    nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),

-    nf_seresnet26=NfCfg(
-        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50=NfCfg(
-        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101=NfCfg(
-        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
+    nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),

-    nf_ecaresnet26=NfCfg(
-        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet50=NfCfg(
-        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet101=NfCfg(
-        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
-        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
-        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
+    nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
+    nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
+    nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
 )
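
The NFNet-F stage depths above follow a simple multiple-of-base pattern; a quick sketch reproducing the f0..f7 table:

```python
# Each nfnet_f{i} uses (1, 2, 6, 3) scaled by (i + 1).
base = (1, 2, 6, 3)
for i in range(8):
    print(f'nfnet_f{i}:', tuple(d * (i + 1) for d in base))
# nfnet_f0: (1, 2, 6, 3) ... nfnet_f7: (8, 16, 48, 24)
```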
@@ -170,20 +235,20 @@ class DownsampleAvg(nn.Module):
         return self.conv(self.pool(x))


-class NormalizationFreeBlock(nn.Module):
-    """Normalization-free pre-activation block.
+class NormFreeBlock(nn.Module):
+    """Normalization-Free pre-activation block.
     """

     def __init__(
             self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
-            alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None,
-            attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False):
+            alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
+            skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
         super().__init__()
         first_dilation = first_dilation or dilation
         out_chs = out_chs or in_chs
-        # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
-        mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
-        groups = 1 if group_size is None else mid_chs // group_size
+        # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
+        mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
+        groups = 1 if not group_size else mid_chs // group_size
         if group_size and group_size % ch_div == 0:
             mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
         self.alpha = alpha
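
A sketch of the channel rounding at play here; `make_divisible` is timm's helper, and this standalone version mirrors the common definition (an assumption, not a verbatim copy):

```python
def make_divisible(v, divisor=8, min_value=None):
    # round v to the nearest multiple of divisor, never going below min_value
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # don't round down by more than 10%
        new_v += divisor
    return new_v

mid_chs = make_divisible(1536 * 0.5, 8)  # NFNet stage 3: bottle_ratio=0.5
groups = mid_chs // 128                  # group_size=128 in the NFNet-F cfgs
print(mid_chs, groups)  # 768 6
```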
@@ -200,12 +265,22 @@ class NormalizationFreeBlock(nn.Module):
         self.conv1 = conv_layer(in_chs, mid_chs, 1)
         self.act2 = act_layer(inplace=True)
         self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
-        if attn_layer is not None:
-            self.attn = attn_layer(mid_chs)
+        if extra_conv:
+            self.act2b = act_layer(inplace=True)
+            self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
+        else:
+            self.act2b = None
+            self.conv2b = None
+        if reg and attn_layer is not None:
+            self.attn = attn_layer(mid_chs)  # RegNet blocks apply attn btw conv2 & 3
         else:
             self.attn = None
         self.act3 = act_layer()
         self.conv3 = conv_layer(mid_chs, out_chs, 1)
+        if not reg and attn_layer is not None:
+            self.attn_last = attn_layer(out_chs)  # ResNet blocks apply attn after conv3
+        else:
+            self.attn_last = None
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
@@ -220,28 +295,48 @@ class NormalizationFreeBlock(nn.Module):
         # residual branch
         out = self.conv1(out)
         out = self.conv2(self.act2(out))
+        if self.conv2b is not None:
+            out = self.conv2b(self.act2b(out))
         if self.attn is not None:
             out = self.attn_gain * self.attn(out)
         out = self.conv3(self.act3(out))
+        if self.attn_last is not None:
+            out = self.attn_gain * self.attn_last(out)
         out = self.drop_path(out)
-        if self.skipinit_gain is None:
-            out = out * self.alpha + shortcut
-        else:
+
+        if self.skipinit_gain is not None:
             # this really slows things down for some reason, TBD
-            out = out * self.alpha * self.skipinit_gain + shortcut
+            out = out * self.skipinit_gain
+        out = out * self.alpha + shortcut
         return out


-def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
+def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
     stem_stride = 2
+    stem_feature = dict(num_chs=out_chs, reduction=2, module='')
     stem = OrderedDict()
-    assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
-    if 'deep' in stem_type:
+    assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
+    if 'deep' in stem_type or 'nff' in stem_type:
         # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
-        mid_chs = out_chs // 2
-        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
-        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
-        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
+        if 'quad' in stem_type:
+            assert not 'pool' in stem_type
+            stem_chs = (16, 32, 64, out_chs)
+            strides = (2, 1, 1, 2)
+            stem_stride = 4
+            stem_feature = dict(num_chs=64, reduction=2, module='stem.act4')
+        else:
+            if 'tiered' in stem_type:
+                stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)  # like 'T' resnets in resnet.py
+            else:
+                stem_chs = (out_chs // 2, out_chs // 2, out_chs)  # 'D' ResNets
+            strides = (2, 1, 1)
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
+        last_idx = len(stem_chs) - 1
+        for i, (c, s) in enumerate(zip(stem_chs, strides)):
+            stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
+            if i != last_idx:
+                stem[f'act{i+2}'] = act_layer(inplace=True)
+            in_chs = c
     elif '3x3' in stem_type:
         # 3x3 stem conv as in RegNet
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
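
A sketch of why skipinit makes each block an identity mapping at init: the learned gain starts at zero, so only the shortcut passes through (values below are illustrative):

```python
import torch

gain = torch.nn.Parameter(torch.tensor(0.))  # skipinit_gain as created above
alpha = 0.2                                  # illustrative alpha
shortcut = torch.randn(4)
branch_out = torch.randn(4)                  # stand-in for the residual branch
out = branch_out * gain * alpha + shortcut   # matches out*gain, then out*alpha + shortcut
print(torch.allclose(out, shortcut))  # True
```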
@@ -253,21 +348,37 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
         stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
         stem_stride = 4

-    return nn.Sequential(stem), stem_stride
+    return nn.Sequential(stem), stem_stride, stem_feature


+# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
 _nonlin_gamma = dict(
-    silu=1./.5595,
-    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
-    identity=1.0
+    identity=1.0,
+    celu=1.270926833152771,
+    elu=1.2716004848480225,
+    gelu=1.7015043497085571,
+    leaky_relu=1.70590341091156,
+    log_sigmoid=1.9193484783172607,
+    log_softmax=1.0002083778381348,
+    relu=1.7139588594436646,
+    relu6=1.7131484746932983,
+    selu=1.0008515119552612,
+    sigmoid=4.803835391998291,
+    silu=1.7881293296813965,
+    softsign=2.338853120803833,
+    softplus=1.9203323125839233,
+    tanh=1.5939117670059204,
 )


-class NormalizerFreeNet(nn.Module):
-    """ Normalizer-free ResNets and RegNets
+class NormFreeNet(nn.Module):
+    """ Normalization-Free Network

-    As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+    As described in :
+    `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
         - https://arxiv.org/abs/2101.08692
+    and
+    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171

     This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
     the (preact) ResNet models described earlier in the paper.
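
The gamma constants above come from the official JAX code; a sketch of how such values can be estimated empirically, as gamma = 1/std(f(x)) for x ~ N(0, 1) so that gamma * f(x) is roughly unit variance again:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1 << 24)  # large sample for a stable estimate
for name, fn in [('relu', F.relu), ('silu', F.silu), ('gelu', F.gelu)]:
    gamma = 1. / fn(x).std()
    print(name, f'{gamma.item():.4f}')  # ~1.7140, ~1.7881, ~1.7015
```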
@@ -278,7 +389,7 @@ class NormalizerFreeNet(nn.Module):
     * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
       impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
     * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
-      apply it in each activation. This is slightly slower, and yields slightly different results.
+      apply it in each activation. This is slightly slower, numerically different, but matches official impl.
     * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
       for what it is/does. Approx 8-10% throughput loss.
     """
@@ -298,10 +409,11 @@ class NormalizerFreeNet(nn.Module):
         stem_chs = cfg.stem_chs or cfg.channels[0]
         stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
-        self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        self.stem, stem_stride, stem_feat = create_stem(
+            in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)

-        self.feature_info = []  # NOTE: there will be no stride == 2 feature if stem_stride == 4
-        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
+        self.feature_info = [stem_feat] if stem_stride == 4 else []
+        drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         prev_chs = stem_chs
         net_stride = stem_stride
         dilation = 1
@@ -309,8 +421,8 @@ class NormalizerFreeNet(nn.Module):
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
             stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
-            self.feature_info += [dict(
-                num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
+            if stride == 2:
+                self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
             if net_stride >= output_stride and stride > 1:
                 dilation *= stride
                 stride = 1
@@ -321,23 +433,24 @@ class NormalizerFreeNet(nn.Module):
             for block_idx in range(cfg.depths[stage_idx]):
                 first_block = block_idx == 0 and stage_idx == 0
                 out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
-                blocks += [NormalizationFreeBlock(
+                blocks += [NormFreeBlock(
                     in_chs=prev_chs, out_chs=out_chs,
                     alpha=cfg.alpha,
-                    beta=1. / expected_var ** 0.5,  # NOTE: beta used as multiplier in block
+                    beta=1. / expected_var ** 0.5,
                     stride=stride if block_idx == 0 else 1,
                     dilation=dilation,
                     first_dilation=first_dilation,
                     group_size=cfg.group_size,
-                    bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
-                    efficient=cfg.efficient,
+                    bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
                     ch_div=cfg.ch_div,
+                    reg=cfg.reg,
+                    extra_conv=cfg.extra_conv,
+                    skipinit=cfg.skipinit,
                     attn_layer=attn_layer,
                     attn_gain=cfg.attn_gain,
                     act_layer=act_layer,
                     conv_layer=conv_layer,
-                    drop_path_rate=dpr[stage_idx][block_idx],
-                    skipinit=cfg.skipinit,
+                    drop_path_rate=drop_path_rates[stage_idx][block_idx],
                 )]
                 if block_idx == 0:
                     expected_var = 1.  # expected var is reset after first block of each stage
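
The `expected_var` bookkeeping that feeds `beta` above, reconstructed as a sketch (the `alpha ** 2` accumulation happens just outside the visible hunk, and alpha=0.2 is the value used in the papers; both are assumptions here):

```python
alpha, depths = 0.2, (1, 2, 6, 3)  # nfnet_f0 depths
expected_var = 1.0
for stage_idx, depth in enumerate(depths):
    for block_idx in range(depth):
        beta = 1. / expected_var ** 0.5  # passed to NormFreeBlock above
        if block_idx == 0:
            expected_var = 1.            # reset after each transition block
        expected_var += alpha ** 2       # each residual add grows variance by alpha^2
    print(f'stage {stage_idx} final-block beta: {beta:.3f}')
```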
@@ -347,27 +460,27 @@ class NormalizerFreeNet(nn.Module):
             stages += [nn.Sequential(*blocks)]
         self.stages = nn.Sequential(*stages)

-        if cfg.efficient and cfg.num_features:
+        if cfg.num_features:
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
             self.final_conv = conv_layer(prev_chs, self.num_features, 1)
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
-        # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
-        self.final_act = act_layer()
+        self.final_act = act_layer(inplace=cfg.num_features > 0)
         self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
             if 'fc' in n and isinstance(m, nn.Linear):
-                nn.init.zeros_(m.weight)
+                if cfg.zero_init_fc:
+                    nn.init.zeros_(m.weight)
+                else:
+                    nn.init.normal_(m.weight, 0., .01)
                 if m.bias is not None:
                     nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Conv2d):
+                # as per discussion with paper authors, original in haiku is
+                # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal') w/ zero'd bias
                 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
                 if m.bias is not None:
                     nn.init.zeros_(m.bias)
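
A sketch checking the init note above: a fan-in variance-scaling normal init draws weights with variance 1/fan_in, which is what `kaiming_normal_` with `mode='fan_in'` and a linear gain produces:

```python
import math
import torch.nn as nn

conv = nn.Conv2d(64, 128, 3)
nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='linear')
fan_in = 64 * 3 * 3
# empirical std of the drawn weights vs the target 1/sqrt(fan_in)
print(f'{conv.weight.std().item():.4f} ~= {1 / math.sqrt(fan_in):.4f}')  # both ~0.0417
```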
@@ -395,17 +508,127 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
     model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
     feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
-    if 'pool' in model_cfg.stem_type:
-        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
+    if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2 feat for stride 4, 1 layer maxpool stems

     return build_model_with_cfg(
-        NormalizerFreeNet, variant, pretrained,
+        NormFreeNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
         model_cfg=model_cfg,
         feature_cfg=feature_cfg,
         **kwargs)
+
+
+@register_model
+def nfnet_f0(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f1(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f2(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f3(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f4(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f5(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f6(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f7(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f0s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f1s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f2s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f3s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f4s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f5s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f6s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f7s(pretrained=False, **kwargs):
+    return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b0(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b1(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b2(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b3(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b4(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b5(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
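
A sketch of the hook-based feature extraction this configures, via timm's standard `features_only` interface (shapes depend on the input size used):

```python
import torch
import timm

m = timm.create_model('nf_regnet_b1', features_only=True, pretrained=False)
print(m.feature_info.channels())  # per-level channel counts from feature_info
feats = m(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
```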
@@ -475,6 +698,7 @@ def nf_ecaresnet26(pretrained=False, **kwargs):
 def nf_ecaresnet50(pretrained=False, **kwargs):
     return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)

+
 @register_model
 def nf_ecaresnet101(pretrained=False, **kwargs):
     return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)