From cb06c7a910cb9b1078679bc67c76afcbb7453d3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Feb 2021 18:28:56 -0800 Subject: [PATCH] Add NFNet-F models and tweak existing NF models. --- timm/models/nfnet.py | 475 ++++++++++++++++++++++++++++++++----------- 1 file changed, 354 insertions(+), 121 deletions(-) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 7b79259c..b9c003e8 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -1,10 +1,19 @@ -""" Normalizer Free RegNet / ResNet (pre-activation) Models +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 -NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some. -Details may change, especially once the paper authors release their official models. + +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Two pretrained weights so far, more to come. +* Model details update to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported Hacked together by / copyright Ross Wightman, 2021. """ @@ -28,37 +37,71 @@ def _dcfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', **kwargs } -default_cfgs = { - 'nf_regnet_b0': _dcfg(url=''), - 'nf_regnet_b1': _dcfg( +default_cfgs = dict( + nfnet_f0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nfnet_f0s=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'), + nfnet_f1s=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'), + nfnet_f2s=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'), + nfnet_f3s=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'), + nfnet_f4s=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'), + nfnet_f5s=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'), + nfnet_f6s=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'), + nfnet_f7s=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'), + + nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nf_regnet_b1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', - pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9), - 'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)), - 'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)), - 'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)), - - 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_resnet50': _dcfg( + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec + nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)), + nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)), + nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)), + nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)), + + nf_resnet26=_dcfg(url=''), + nf_resnet50=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', - first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94), - 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'), + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94), + nf_resnet101=_dcfg(url=''), - 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'), + nf_seresnet26=_dcfg(url=''), + nf_seresnet50=_dcfg(url=''), + nf_seresnet101=_dcfg(url=''), - 'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'), - 'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'), -} + nf_ecaresnet26=_dcfg(url=''), + nf_ecaresnet50=_dcfg(url=''), + nf_ecaresnet101=_dcfg(url=''), +) @dataclass @@ -69,69 +112,95 @@ class NfCfg: gamma_in_act: bool = False stem_type: str = '3x3' stem_chs: Optional[int] = None - group_size: Optional[int] = 8 - attn_layer: Optional[str] = 'se' - attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8)) + group_size: Optional[int] = None + attn_layer: Optional[str] = None + attn_kwargs: dict = None attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used - width_factor: float = 0.75 - bottle_ratio: float = 2.25 - efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models - num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode) + width_factor: float = 1.0 + bottle_ratio: float = 0.5 + num_features: int = 0 # num out_channels for final conv, no final_conv if 0 ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal - skipinit: bool = False + reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle + extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + skipinit: bool = False # disabled by default, non-trivial performance impact + zero_init_fc: bool = False act_layer: str = 'silu' +def _nfres_cfg( + depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + attn_kwargs = attn_kwargs or {} + cfg = NfCfg( + depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, + group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + +def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): + num_features = 1280 * channels[-1] // 440 + attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, + num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + return cfg + + +def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None): + channels = (256, 512, 1536, 1536) + num_features = channels[-1] * 2 + attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True, + num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + model_cfgs = dict( - # EffNet influenced RegNet defs - nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280), - nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416), - nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536), - nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792), - nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048), + # NFNet-F models w/ GeLU + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), + nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), + nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), + nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), + nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), + nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), + nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), + nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + # FIXME add remainder if silu vs gelu proves worthwhile + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. + nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + # FIXME add B6-B8 # ResNet (preact, D style deep stem/avg down) defs - nf_resnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None,), - nf_resnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - nf_resnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer=None), - - - nf_seresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - nf_seresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), - - - nf_ecaresnet26=NfCfg( - depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet50=NfCfg( - depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), - nf_ecaresnet101=NfCfg( - depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048), - stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, - act_layer='relu', attn_layer='eca', attn_kwargs=dict()), + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()), ) @@ -170,20 +239,20 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -class NormalizationFreeBlock(nn.Module): +class NormFreeBlock(nn.Module): """Normalization-free pre-activation block. """ def __init__( self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, - alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None, - attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False): + alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, + skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs - # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet - mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div) - groups = 1 if group_size is None else mid_chs // group_size + # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div) + groups = 1 if not group_size else mid_chs // group_size if group_size and group_size % ch_div == 0: mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error self.alpha = alpha @@ -200,12 +269,22 @@ class NormalizationFreeBlock(nn.Module): self.conv1 = conv_layer(in_chs, mid_chs, 1) self.act2 = act_layer(inplace=True) self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - if attn_layer is not None: - self.attn = attn_layer(mid_chs) + if extra_conv: + self.act2b = act_layer(inplace=True) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + else: + self.act2b = None + self.conv2b = None + if reg and attn_layer is not None: + self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 else: self.attn = None self.act3 = act_layer() self.conv3 = conv_layer(mid_chs, out_chs, 1) + if not reg and attn_layer is not None: + self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + else: + self.attn_last = None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None @@ -220,28 +299,60 @@ class NormalizationFreeBlock(nn.Module): # residual branch out = self.conv1(out) out = self.conv2(self.act2(out)) + if self.conv2b is not None: + out = self.conv2b(self.act2b(out)) if self.attn is not None: out = self.attn_gain * self.attn(out) out = self.conv3(self.act3(out)) + if self.attn_last is not None: + out = self.attn_gain * self.attn_last(out) out = self.drop_path(out) - if self.skipinit_gain is None: - out = out * self.alpha + shortcut - else: + + if self.skipinit_gain is not None: # this really slows things down for some reason, TBD - out = out * self.alpha * self.skipinit_gain + shortcut + out = out * self.skipinit_gain + out = out * self.alpha + shortcut return out -def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): +def stem_info(stem_type): + stem_stride = 2 + if 'nff' in stem_type or 'pool' in stem_type: + stem_stride = 4 + stem_feat = '' + if 'nff' in stem_type: + stem_feat = 'stem.act3' + elif 'deep' in stem_type and not 'pool' in stem_type: + stem_feat = 'stem.act2' + return stem_stride, stem_feat + + +def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None): stem_stride = 2 + stem_feature = '' stem = OrderedDict() - assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') - if 'deep' in stem_type: + assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + if 'deep' in stem_type or 'nff' in stem_type: # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here - mid_chs = out_chs // 2 - stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) - stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) - stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if 'nff' in stem_type: + assert not 'pool' in stem_type + stem_chs = (16, 32, 64, out_chs) + strides = (2, 1, 1, 2) + stem_stride = 4 + stem_feature = 'stem.act4' + else: + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) + else: + stem_chs = (out_chs // 2, out_chs // 2, out_chs) + strides = (2, 1, 1) + stem_feature = 'stem.act3' + last_idx = len(stem_chs) - 1 + for i, (c, s) in enumerate(zip(stem_chs, strides)): + stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + if i != last_idx: + stem[f'act{i+2}'] = act_layer(inplace=True) + in_chs = c elif '3x3' in stem_type: # 3x3 stem conv as in RegNet stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) @@ -253,18 +364,30 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) stem_stride = 4 - return nn.Sequential(stem), stem_stride + return nn.Sequential(stem), stem_stride, stem_feature _nonlin_gamma = dict( - silu=1./.5595, - relu=(0.5 * (1. - 1. / math.pi)) ** -0.5, - identity=1.0 + identity=1.0, + celu=1.270926833152771, + elu=1.2716004848480225, + gelu=1.7015043497085571, + leaky_relu=1.70590341091156, + log_sigmoid=1.9193484783172607, + log_softmax=1.0002083778381348, + relu=1.7139588594436646, + relu6=1.7131484746932983, + selu=1.0008515119552612, + sigmoid=4.803835391998291, + silu=1.7881293296813965, + softsign=2.338853120803833, + softplus=1.9203323125839233, + tanh=1.5939117670059204, ) -class NormalizerFreeNet(nn.Module): - """ Normalizer-free ResNets and RegNets +class NormFreeNet(nn.Module): + """ Normalization-free ResNets and RegNets As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - https://arxiv.org/abs/2101.08692 @@ -298,10 +421,11 @@ class NormalizerFreeNet(nn.Module): stem_chs = cfg.stem_chs or cfg.channels[0] stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div) - self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer) + self.stem, stem_stride, stem_feat = create_stem( + in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) - self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4 - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else [] + drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] prev_chs = stem_chs net_stride = stem_stride dilation = 1 @@ -309,8 +433,8 @@ class NormalizerFreeNet(nn.Module): stages = [] for stage_idx, stage_depth in enumerate(cfg.depths): stride = 1 if stage_idx == 0 and stem_stride > 2 else 2 - self.feature_info += [dict( - num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')] + if stride == 2: + self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')] if net_stride >= output_stride and stride > 1: dilation *= stride stride = 1 @@ -321,7 +445,7 @@ class NormalizerFreeNet(nn.Module): for block_idx in range(cfg.depths[stage_idx]): first_block = block_idx == 0 and stage_idx == 0 out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div) - blocks += [NormalizationFreeBlock( + blocks += [NormFreeBlock( in_chs=prev_chs, out_chs=out_chs, alpha=cfg.alpha, beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block @@ -329,15 +453,16 @@ class NormalizerFreeNet(nn.Module): dilation=dilation, first_dilation=first_dilation, group_size=cfg.group_size, - bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio, - efficient=cfg.efficient, + bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio, ch_div=cfg.ch_div, + reg=cfg.reg, + extra_conv=cfg.extra_conv, + skipinit=cfg.skipinit, attn_layer=attn_layer, attn_gain=cfg.attn_gain, act_layer=act_layer, conv_layer=conv_layer, - drop_path_rate=dpr[stage_idx][block_idx], - skipinit=cfg.skipinit, + drop_path_rate=drop_path_rates[stage_idx][block_idx], )] if block_idx == 0: expected_var = 1. # expected var is reset after first block of each stage @@ -347,22 +472,25 @@ class NormalizerFreeNet(nn.Module): stages += [nn.Sequential(*blocks)] self.stages = nn.Sequential(*stages) - if cfg.efficient and cfg.num_features: + if cfg.num_features: # The paper NFRegNet models have an EfficientNet-like final head convolution. self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) self.final_conv = conv_layer(prev_chs, self.num_features, 1) + # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv else: self.num_features = prev_chs self.final_conv = nn.Identity() - # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv - self.final_act = act_layer() + self.final_act = act_layer(inplace=cfg.num_features > 0) self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')] self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) for n, m in self.named_modules(): if 'fc' in n and isinstance(m, nn.Linear): - nn.init.zeros_(m.weight) + if cfg.zero_init_fc: + nn.init.zeros_(m.weight) + else: + nn.init.normal_(m.weight, 0., .01) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): @@ -395,17 +523,121 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): model_cfg = model_cfgs[variant] feature_cfg = dict(flatten_sequential=True) feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks - if 'pool' in model_cfg.stem_type: - feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet + if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type: + feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems return build_model_with_cfg( - NormalizerFreeNet, variant, pretrained, + NormFreeNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfg, feature_cfg=feature_cfg, **kwargs) +@register_model +def nfnet_f0(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) + + +def nfnet_f4(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) + + +def nfnet_f5(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) + + +def nfnet_f7(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) + + +def nfnet_f4s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) + + +def nfnet_f5s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) + + +def nfnet_f7s(pretrained=False, **kwargs): + return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b0(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b1(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b2(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b3(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b4(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b5(pretrained=False, **kwargs): + return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) @@ -475,6 +707,7 @@ def nf_ecaresnet26(pretrained=False, **kwargs): def nf_ecaresnet50(pretrained=False, **kwargs): return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) + @register_model def nf_ecaresnet101(pretrained=False, **kwargs): return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) \ No newline at end of file