Merge pull request #427 from rwightman/nfnet

Add NFNet-F models and tweak existing NF models.
pull/437/head
Ross Wightman 3 years ago committed by GitHub
commit 4df513c68f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,9 @@
## What's New
### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
### Feb 10, 2021
* First Normalization-Free model training experiments done,
* nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
@ -164,6 +167,7 @@ A full version of the list below with source links can be found in the [document
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* NASNet-A - https://arxiv.org/abs/1707.07012
* NFNet-F - https://arxiv.org/abs/2102.06171
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
* RegNet - https://arxiv.org/abs/2003.13678

@ -19,7 +19,9 @@ NON_STD_FILTERS = ['vit_*']
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
'nfnet_f4*', 'nfnet_f5*', 'nfnet_f6*', 'nfnet_f7*'] + NON_STD_FILTERS
else:
EXCLUDE_FILTERS = NON_STD_FILTERS

@ -1,10 +1,18 @@
""" Normalizer Free RegNet / ResNet (pre-activation) Models
""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some.
Details may change, especially once the paper authors release their official models.
Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Status:
* These models are a work in progress, experiments ongoing.
* Pretrained weights for two models so far, more to come.
* Model details updated to more closely match the official JAX code now that it's released
* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
Hacked together by / copyright Ross Wightman, 2021.
"""
@ -28,37 +36,71 @@ def _dcfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
'nf_regnet_b0': _dcfg(url=''),
'nf_regnet_b1': _dcfg(
default_cfgs = dict(
nfnet_f0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_f1=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
nfnet_f2=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
nfnet_f3=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
nfnet_f4=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
nfnet_f5=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
nfnet_f6=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
nfnet_f7=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
nfnet_f0s=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
nfnet_f1s=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
nfnet_f2s=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
nfnet_f3s=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
nfnet_f4s=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
nfnet_f5s=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
nfnet_f6s=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
nfnet_f7s=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nf_regnet_b1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9),
'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)),
'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)),
'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)),
'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_resnet50': _dcfg(
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
nf_resnet26=_dcfg(url=''),
nf_resnet50=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94),
'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
nf_resnet101=_dcfg(url=''),
'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
nf_seresnet26=_dcfg(url=''),
nf_seresnet50=_dcfg(url=''),
nf_seresnet101=_dcfg(url=''),
'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
}
nf_ecaresnet26=_dcfg(url=''),
nf_ecaresnet50=_dcfg(url=''),
nf_ecaresnet101=_dcfg(url=''),
)
@dataclass
@ -69,69 +111,92 @@ class NfCfg:
gamma_in_act: bool = False
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = 8
attn_layer: Optional[str] = 'se'
attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8))
group_size: Optional[int] = None
attn_layer: Optional[str] = None
attn_kwargs: dict = None
attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
width_factor: float = 0.75
bottle_ratio: float = 2.25
efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models
num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode)
width_factor: float = 1.0
bottle_ratio: float = 0.5
num_features: int = 0 # num out_channels for final conv, no final_conv if 0
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
skipinit: bool = False
reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
skipinit: bool = False # disabled by default, non-trivial performance impact
zero_init_fc: bool = False
act_layer: str = 'silu'
def _nfres_cfg(
        depths,
        channels=(256, 512, 1024, 2048),
        group_size=None,
        act_layer='relu',
        attn_layer=None,
        attn_kwargs=None,
):
    """Build the NfCfg used by the NF-ResNet family entries in model_cfgs.

    The stem is fixed to '7x7_pool' with 64 stem channels and the bottleneck
    ratio to 0.25; everything else is passed through from the arguments.

    Args:
        depths: per-stage block counts, forwarded to NfCfg.
        channels: per-stage output channels, forwarded to NfCfg.
        group_size: conv group size, forwarded to NfCfg (None by default).
        act_layer: activation name, forwarded to NfCfg.
        attn_layer: attention layer name, forwarded to NfCfg (None by default).
        attn_kwargs: kwargs for the attention layer; any falsy value
            (None or an empty dict) is replaced by a fresh empty dict.

    Returns:
        The constructed NfCfg instance.
    """
    return NfCfg(
        depths=depths,
        channels=channels,
        stem_type='7x7_pool',
        stem_chs=64,
        bottle_ratio=0.25,
        group_size=group_size,
        act_layer=act_layer,
        attn_layer=attn_layer,
        attn_kwargs=attn_kwargs if attn_kwargs else {},
    )
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
    """Build the NfCfg used by the nf_regnet_b* entries in model_cfgs.

    Args:
        depths: per-stage block counts, forwarded to NfCfg.
        channels: per-stage output channels, forwarded to NfCfg.

    Returns:
        An NfCfg with a '3x3' stem, group_size 8, width_factor 0.75,
        bottle_ratio 2.25, reg=True and an 'se' attention layer.
    """
    # Final conv width scales with the last stage width (integer math):
    # the default 440-channel last stage yields 1280 features.
    head_features = 1280 * channels[-1] // 440
    se_kwargs = {'reduction_ratio': 0.5, 'divisor': 8}
    return NfCfg(
        depths=depths,
        channels=channels,
        stem_type='3x3',
        group_size=8,
        width_factor=0.75,
        bottle_ratio=2.25,
        num_features=head_features,
        reg=True,
        attn_layer='se',
        attn_kwargs=se_kwargs,
    )
def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
    """Build the NfCfg used by the nfnet_f* entries in model_cfgs.

    Args:
        depths: per-stage block counts, forwarded to NfCfg.
        act_layer: activation name, forwarded to NfCfg.
        attn_layer: attention layer name, forwarded to NfCfg.
        attn_kwargs: kwargs for the attention layer; any falsy value falls
            back to dict(reduction_ratio=0.5, divisor=8).

    Returns:
        An NfCfg with fixed stage widths (256, 512, 1536, 1536), a 'deep_quad'
        stem, group_size 128, bottle_ratio 0.5 and extra_conv enabled.
    """
    stage_chs = (256, 512, 1536, 1536)
    # Final conv is twice as wide as the last stage.
    head_features = 2 * stage_chs[-1]
    if not attn_kwargs:
        attn_kwargs = {'reduction_ratio': 0.5, 'divisor': 8}
    return NfCfg(
        depths=depths,
        channels=stage_chs,
        stem_type='deep_quad',
        group_size=128,
        bottle_ratio=0.5,
        extra_conv=True,
        num_features=head_features,
        act_layer=act_layer,
        attn_layer=attn_layer,
        attn_kwargs=attn_kwargs,
    )
model_cfgs = dict(
# EffNet influenced RegNet defs
nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280),
nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280),
nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416),
nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536),
nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792),
nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
# NFNet-F models w/ GeLU
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
# NFNet-F models w/ SiLU (much faster in PyTorch)
nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
# EffNet influenced RegNet defs.
# NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
# FIXME add B6-B8
# ResNet (preact, D style deep stem/avg down) defs
nf_resnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None,),
nf_resnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_resnet101=NfCfg(
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_seresnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet101=NfCfg(
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_ecaresnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101=NfCfg(
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
)
@ -170,20 +235,20 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
class NormalizationFreeBlock(nn.Module):
"""Normalization-free pre-activation block.
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None,
attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False):
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
# EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
groups = 1 if group_size is None else mid_chs // group_size
# RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
groups = 1 if not group_size else mid_chs // group_size
if group_size and group_size % ch_div == 0:
mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
self.alpha = alpha
@ -200,12 +265,22 @@ class NormalizationFreeBlock(nn.Module):
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.act2 = act_layer(inplace=True)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
if attn_layer is not None:
self.attn = attn_layer(mid_chs)
if extra_conv:
self.act2b = act_layer(inplace=True)
self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
else:
self.act2b = None
self.conv2b = None
if reg and attn_layer is not None:
self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3
else:
self.attn = None
self.act3 = act_layer()
self.conv3 = conv_layer(mid_chs, out_chs, 1)
if not reg and attn_layer is not None:
self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3
else:
self.attn_last = None
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
@ -220,28 +295,48 @@ class NormalizationFreeBlock(nn.Module):
# residual branch
out = self.conv1(out)
out = self.conv2(self.act2(out))
if self.conv2b is not None:
out = self.conv2b(self.act2b(out))
if self.attn is not None:
out = self.attn_gain * self.attn(out)
out = self.conv3(self.act3(out))
if self.attn_last is not None:
out = self.attn_gain * self.attn_last(out)
out = self.drop_path(out)
if self.skipinit_gain is None:
out = out * self.alpha + shortcut
else:
if self.skipinit_gain is not None:
# this really slows things down for some reason, TBD
out = out * self.alpha * self.skipinit_gain + shortcut
out = out * self.skipinit_gain
out = out * self.alpha + shortcut
return out
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
stem_stride = 2
stem_feature = dict(num_chs=out_chs, reduction=2, module='')
stem = OrderedDict()
assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
if 'deep' in stem_type:
assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
if 'deep' in stem_type or 'nff' in stem_type:
# 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
if 'quad' in stem_type:
assert not 'pool' in stem_type
stem_chs = (16, 32, 64, out_chs)
strides = (2, 1, 1, 2)
stem_stride = 4
stem_feature = dict(num_chs=64, reduction=2, module='stem.act4')
else:
if 'tiered' in stem_type:
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # like 'T' resnets in resnet.py
else:
stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
strides = (2, 1, 1)
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
last_idx = len(stem_chs) - 1
for i, (c, s) in enumerate(zip(stem_chs, strides)):
stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
if i != last_idx:
stem[f'act{i+2}'] = act_layer(inplace=True)
in_chs = c
elif '3x3' in stem_type:
# 3x3 stem conv as in RegNet
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
@ -253,21 +348,37 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
stem_stride = 4
return nn.Sequential(stem), stem_stride
return nn.Sequential(stem), stem_stride, stem_feature
# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
_nonlin_gamma = dict(
silu=1./.5595,
relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
identity=1.0
identity=1.0,
celu=1.270926833152771,
elu=1.2716004848480225,
gelu=1.7015043497085571,
leaky_relu=1.70590341091156,
log_sigmoid=1.9193484783172607,
log_softmax=1.0002083778381348,
relu=1.7139588594436646,
relu6=1.7131484746932983,
selu=1.0008515119552612,
sigmoid=4.803835391998291,
silu=1.7881293296813965,
softsign=2.338853120803833,
softplus=1.9203323125839233,
tanh=1.5939117670059204,
)
class NormalizerFreeNet(nn.Module):
""" Normalizer-free ResNets and RegNets
class NormFreeNet(nn.Module):
""" Normalization-Free Network
As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
As described in :
`Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
and
`High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
the (preact) ResNet models described earlier in the paper.
@ -278,7 +389,7 @@ class NormalizerFreeNet(nn.Module):
* activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
* a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
apply it in each activation. This is slightly slower, and yields slightly different results.
apply it in each activation. This is slightly slower, numerically different, but matches official impl.
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
for what it is/does. Approx 8-10% throughput loss.
"""
@ -298,10 +409,11 @@ class NormalizerFreeNet(nn.Module):
stem_chs = cfg.stem_chs or cfg.channels[0]
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
self.feature_info = [stem_feat] if stem_stride == 4 else []
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
prev_chs = stem_chs
net_stride = stem_stride
dilation = 1
@ -309,8 +421,8 @@ class NormalizerFreeNet(nn.Module):
stages = []
for stage_idx, stage_depth in enumerate(cfg.depths):
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
self.feature_info += [dict(
num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
if stride == 2:
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
@ -321,23 +433,24 @@ class NormalizerFreeNet(nn.Module):
for block_idx in range(cfg.depths[stage_idx]):
first_block = block_idx == 0 and stage_idx == 0
out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
blocks += [NormalizationFreeBlock(
blocks += [NormFreeBlock(
in_chs=prev_chs, out_chs=out_chs,
alpha=cfg.alpha,
beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block
beta=1. / expected_var ** 0.5,
stride=stride if block_idx == 0 else 1,
dilation=dilation,
first_dilation=first_dilation,
group_size=cfg.group_size,
bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
efficient=cfg.efficient,
bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
ch_div=cfg.ch_div,
reg=cfg.reg,
extra_conv=cfg.extra_conv,
skipinit=cfg.skipinit,
attn_layer=attn_layer,
attn_gain=cfg.attn_gain,
act_layer=act_layer,
conv_layer=conv_layer,
drop_path_rate=dpr[stage_idx][block_idx],
skipinit=cfg.skipinit,
drop_path_rate=drop_path_rates[stage_idx][block_idx],
)]
if block_idx == 0:
expected_var = 1. # expected var is reset after first block of each stage
@ -347,27 +460,27 @@ class NormalizerFreeNet(nn.Module):
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)
if cfg.efficient and cfg.num_features:
if cfg.num_features:
# The paper NFRegNet models have an EfficientNet-like final head convolution.
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
# FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
self.final_act = act_layer()
self.final_act = act_layer(inplace=cfg.num_features > 0)
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
if 'fc' in n and isinstance(m, nn.Linear):
nn.init.zeros_(m.weight)
if cfg.zero_init_fc:
nn.init.zeros_(m.weight)
else:
nn.init.normal_(m.weight, 0., .01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
# as per discussion with paper authors, original in haiku is
# hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal') w/ zero'd bias
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
if m.bias is not None:
nn.init.zeros_(m.bias)
@ -395,17 +508,127 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
model_cfg = model_cfgs[variant]
feature_cfg = dict(flatten_sequential=True)
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
if 'pool' in model_cfg.stem_type:
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems
return build_model_with_cfg(
NormalizerFreeNet, variant, pretrained,
NormFreeNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfg,
feature_cfg=feature_cfg,
**kwargs)
# Model entry point for the 'nfnet_f0' variant (NFNet-F0, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f0(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f1' variant (NFNet-F1, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f1(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f2' variant (NFNet-F2, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f2(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f3' variant (NFNet-F3, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f3(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f4' variant (NFNet-F4, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f4(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f5' variant (NFNet-F5, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f5(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f6' variant (NFNet-F6, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f6(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f7' variant (NFNet-F7, GELU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f7(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f0s' variant (NFNet-F0 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f0s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f1s' variant (NFNet-F1 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f1s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f2s' variant (NFNet-F2 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f2s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f3s' variant (NFNet-F3 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f3s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f4s' variant (NFNet-F4 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f4s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f5s' variant (NFNet-F5 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f5s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f6s' variant (NFNet-F6 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f6s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nfnet_f7s' variant (NFNet-F7 with the SiLU config).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nfnet_f7s(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b0' variant (NF-RegNet-B0).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b1' variant (NF-RegNet-B1).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b1(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b2' variant (NF-RegNet-B2).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b2(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b3' variant (NF-RegNet-B3).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b3(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b4' variant (NF-RegNet-B4).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b4(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b5' variant (NF-RegNet-B5).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_regnet_b5(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_regnet_b0' variant (NF-RegNet-B0).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
# NOTE(review): an identical nf_regnet_b0 entry point appears earlier in this
# listing. This may just be diff context; if both really exist in the file,
# the later definition rebinds the name -- confirm and drop one.
@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
@ -475,6 +698,7 @@ def nf_ecaresnet26(pretrained=False, **kwargs):
def nf_ecaresnet50(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
# Model entry point for the 'nf_ecaresnet101' variant (its model_cfgs entry
# uses depths (3, 4, 23, 3) with the 'eca' attention layer).
# Forwards `pretrained` and any extra kwargs to _create_normfreenet.
@register_model
def nf_ecaresnet101(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
Loading…
Cancel
Save