|
|
@ -1,10 +1,19 @@
|
|
|
|
""" Normalizer Free RegNet / ResNet (pre-activation) Models
|
|
|
|
""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models
|
|
|
|
|
|
|
|
|
|
|
|
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
|
|
|
|
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
|
|
|
|
- https://arxiv.org/abs/2101.08692
|
|
|
|
- https://arxiv.org/abs/2101.08692
|
|
|
|
|
|
|
|
|
|
|
|
NOTE: These models are a work in progress, no pretrained weights yet but I'm currently training some.
|
|
|
|
|
|
|
|
Details may change, especially once the paper authors release their official models.
|
|
|
|
Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
|
|
|
|
|
|
|
|
- https://arxiv.org/abs/2102.06171
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Status:
|
|
|
|
|
|
|
|
* These models are a work in progress, experiments ongoing.
|
|
|
|
|
|
|
|
* Two pretrained weights so far, more to come.
|
|
|
|
|
|
|
|
* Model details update to closer match official JAX code now that it's released
|
|
|
|
|
|
|
|
* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by / copyright Ross Wightman, 2021.
|
|
|
|
Hacked together by / copyright Ross Wightman, 2021.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -28,37 +37,71 @@ def _dcfg(url='', **kwargs):
|
|
|
|
return {
|
|
|
|
return {
|
|
|
|
'url': url,
|
|
|
|
'url': url,
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
|
|
|
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
|
|
|
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
|
|
|
**kwargs
|
|
|
|
**kwargs
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
default_cfgs = dict(
|
|
|
|
'nf_regnet_b0': _dcfg(url=''),
|
|
|
|
nfnet_f0=_dcfg(
|
|
|
|
'nf_regnet_b1': _dcfg(
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f1=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f2=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f3=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f4=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f5=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f6=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f7=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nfnet_f0s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f1s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f2s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f3s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f4s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f5s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f6s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
nfnet_f7s=_dcfg(
|
|
|
|
|
|
|
|
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
|
|
|
|
|
|
|
nf_regnet_b1=_dcfg(
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
|
|
|
|
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.9),
|
|
|
|
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
|
|
|
|
'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
|
|
|
|
'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272), pool_size=(9, 9)),
|
|
|
|
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
|
|
|
|
'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320), pool_size=(10, 10)),
|
|
|
|
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
|
|
|
|
'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384), pool_size=(12, 12)),
|
|
|
|
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
|
|
|
|
|
|
|
|
|
|
|
|
'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_resnet26=_dcfg(url=''),
|
|
|
|
'nf_resnet50': _dcfg(
|
|
|
|
nf_resnet50=_dcfg(
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
|
|
|
|
first_conv='stem.conv', pool_size=(8, 8), input_size=(3, 256, 256), crop_pct=0.94),
|
|
|
|
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
|
|
|
|
'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_resnet101=_dcfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_seresnet26=_dcfg(url=''),
|
|
|
|
'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_seresnet50=_dcfg(url=''),
|
|
|
|
'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_seresnet101=_dcfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_ecaresnet26=_dcfg(url=''),
|
|
|
|
'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_ecaresnet50=_dcfg(url=''),
|
|
|
|
'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
|
|
|
|
nf_ecaresnet101=_dcfg(url=''),
|
|
|
|
}
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
@ -69,69 +112,95 @@ class NfCfg:
|
|
|
|
gamma_in_act: bool = False
|
|
|
|
gamma_in_act: bool = False
|
|
|
|
stem_type: str = '3x3'
|
|
|
|
stem_type: str = '3x3'
|
|
|
|
stem_chs: Optional[int] = None
|
|
|
|
stem_chs: Optional[int] = None
|
|
|
|
group_size: Optional[int] = 8
|
|
|
|
group_size: Optional[int] = None
|
|
|
|
attn_layer: Optional[str] = 'se'
|
|
|
|
attn_layer: Optional[str] = None
|
|
|
|
attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8))
|
|
|
|
attn_kwargs: dict = None
|
|
|
|
attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
|
|
|
|
attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
|
|
|
|
width_factor: float = 0.75
|
|
|
|
width_factor: float = 1.0
|
|
|
|
bottle_ratio: float = 2.25
|
|
|
|
bottle_ratio: float = 0.5
|
|
|
|
efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models
|
|
|
|
num_features: int = 0 # num out_channels for final conv, no final_conv if 0
|
|
|
|
num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode)
|
|
|
|
|
|
|
|
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
|
|
|
|
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
|
|
|
|
skipinit: bool = False
|
|
|
|
reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
|
|
|
|
|
|
|
|
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
|
|
|
|
|
|
|
|
skipinit: bool = False # disabled by default, non-trivial performance impact
|
|
|
|
|
|
|
|
zero_init_fc: bool = False
|
|
|
|
act_layer: str = 'silu'
|
|
|
|
act_layer: str = 'silu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _nfres_cfg(
|
|
|
|
|
|
|
|
depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
|
|
|
|
|
|
|
|
attn_kwargs = attn_kwargs or {}
|
|
|
|
|
|
|
|
cfg = NfCfg(
|
|
|
|
|
|
|
|
depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
|
|
|
|
|
|
|
|
group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
|
|
|
|
|
|
|
|
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
|
|
|
|
|
|
|
|
num_features = 1280 * channels[-1] // 440
|
|
|
|
|
|
|
|
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
|
|
|
|
|
|
|
|
cfg = NfCfg(
|
|
|
|
|
|
|
|
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
|
|
|
|
|
|
|
|
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
|
|
|
|
|
|
|
|
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
|
|
|
|
|
|
|
|
channels = (256, 512, 1536, 1536)
|
|
|
|
|
|
|
|
num_features = channels[-1] * 2
|
|
|
|
|
|
|
|
attn_kwargs = attn_kwargs or dict(reduction_ratio=0.5, divisor=8)
|
|
|
|
|
|
|
|
cfg = NfCfg(
|
|
|
|
|
|
|
|
depths=depths, channels=channels, stem_type='nff', group_size=128, bottle_ratio=0.5, extra_conv=True,
|
|
|
|
|
|
|
|
num_features=num_features, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
|
|
|
|
|
|
|
|
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cfgs = dict(
|
|
|
|
model_cfgs = dict(
|
|
|
|
# EffNet influenced RegNet defs
|
|
|
|
# NFNet-F models w/ GeLU
|
|
|
|
nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280),
|
|
|
|
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
|
|
|
|
nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280),
|
|
|
|
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
|
|
|
|
nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416),
|
|
|
|
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
|
|
|
|
nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536),
|
|
|
|
nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
|
|
|
|
nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792),
|
|
|
|
nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
|
|
|
|
nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
|
|
|
|
nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
|
|
|
|
|
|
|
|
nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
|
|
|
|
|
|
|
|
nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NFNet-F models w/ SiLU (much faster in PyTorch)
|
|
|
|
|
|
|
|
nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
|
|
|
|
|
|
|
|
nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NFNet-F models w/ SiLU (much faster in PyTorch)
|
|
|
|
|
|
|
|
# FIXME add remainder if silu vs gelu proves worthwhile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# EffNet influenced RegNet defs.
|
|
|
|
|
|
|
|
# NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
|
|
|
|
|
|
|
|
nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
|
|
|
|
|
|
|
|
nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
|
|
|
|
|
|
|
|
nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
|
|
|
|
|
|
|
|
nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
|
|
|
|
|
|
|
|
nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
|
|
|
|
|
|
|
|
nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
|
|
|
|
|
|
|
|
# FIXME add B6-B8
|
|
|
|
|
|
|
|
|
|
|
|
# ResNet (preact, D style deep stem/avg down) defs
|
|
|
|
# ResNet (preact, D style deep stem/avg down) defs
|
|
|
|
nf_resnet26=NfCfg(
|
|
|
|
nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
|
|
|
|
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
|
|
|
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
|
|
|
|
act_layer='relu', attn_layer=None,),
|
|
|
|
|
|
|
|
nf_resnet50=NfCfg(
|
|
|
|
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
act_layer='relu', attn_layer=None),
|
|
|
|
|
|
|
|
nf_resnet101=NfCfg(
|
|
|
|
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
act_layer='relu', attn_layer=None),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nf_seresnet26=NfCfg(
|
|
|
|
|
|
|
|
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
|
|
|
|
nf_seresnet50=NfCfg(
|
|
|
|
|
|
|
|
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
|
|
|
|
nf_seresnet101=NfCfg(
|
|
|
|
|
|
|
|
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nf_ecaresnet26=NfCfg(
|
|
|
|
|
|
|
|
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
|
|
|
|
nf_ecaresnet50=NfCfg(
|
|
|
|
|
|
|
|
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
|
|
|
|
nf_ecaresnet101=NfCfg(
|
|
|
|
|
|
|
|
depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
|
|
|
|
|
|
|
|
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
|
|
|
|
|
|
|
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -170,20 +239,20 @@ class DownsampleAvg(nn.Module):
|
|
|
|
return self.conv(self.pool(x))
|
|
|
|
return self.conv(self.pool(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NormalizationFreeBlock(nn.Module):
|
|
|
|
class NormFreeBlock(nn.Module):
|
|
|
|
"""Normalization-free pre-activation block.
|
|
|
|
"""Normalization-free pre-activation block.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
|
|
|
|
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
|
|
|
|
alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None,
|
|
|
|
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
|
|
|
|
attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False):
|
|
|
|
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
out_chs = out_chs or in_chs
|
|
|
|
out_chs = out_chs or in_chs
|
|
|
|
# EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
|
|
|
|
# RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
|
|
|
|
mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
|
|
|
|
mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
|
|
|
|
groups = 1 if group_size is None else mid_chs // group_size
|
|
|
|
groups = 1 if not group_size else mid_chs // group_size
|
|
|
|
if group_size and group_size % ch_div == 0:
|
|
|
|
if group_size and group_size % ch_div == 0:
|
|
|
|
mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
|
|
|
|
mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
|
|
|
|
self.alpha = alpha
|
|
|
|
self.alpha = alpha
|
|
|
@ -200,12 +269,22 @@ class NormalizationFreeBlock(nn.Module):
|
|
|
|
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
|
|
|
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
|
|
|
self.act2 = act_layer(inplace=True)
|
|
|
|
self.act2 = act_layer(inplace=True)
|
|
|
|
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
|
|
|
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
|
|
|
if attn_layer is not None:
|
|
|
|
if extra_conv:
|
|
|
|
self.attn = attn_layer(mid_chs)
|
|
|
|
self.act2b = act_layer(inplace=True)
|
|
|
|
|
|
|
|
self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.act2b = None
|
|
|
|
|
|
|
|
self.conv2b = None
|
|
|
|
|
|
|
|
if reg and attn_layer is not None:
|
|
|
|
|
|
|
|
self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.attn = None
|
|
|
|
self.attn = None
|
|
|
|
self.act3 = act_layer()
|
|
|
|
self.act3 = act_layer()
|
|
|
|
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
|
|
|
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
|
|
|
|
|
|
|
if not reg and attn_layer is not None:
|
|
|
|
|
|
|
|
self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.attn_last = None
|
|
|
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
|
|
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
|
|
|
self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
|
|
|
|
self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
|
|
|
|
|
|
|
|
|
|
|
@ -220,28 +299,60 @@ class NormalizationFreeBlock(nn.Module):
|
|
|
|
# residual branch
|
|
|
|
# residual branch
|
|
|
|
out = self.conv1(out)
|
|
|
|
out = self.conv1(out)
|
|
|
|
out = self.conv2(self.act2(out))
|
|
|
|
out = self.conv2(self.act2(out))
|
|
|
|
|
|
|
|
if self.conv2b is not None:
|
|
|
|
|
|
|
|
out = self.conv2b(self.act2b(out))
|
|
|
|
if self.attn is not None:
|
|
|
|
if self.attn is not None:
|
|
|
|
out = self.attn_gain * self.attn(out)
|
|
|
|
out = self.attn_gain * self.attn(out)
|
|
|
|
out = self.conv3(self.act3(out))
|
|
|
|
out = self.conv3(self.act3(out))
|
|
|
|
|
|
|
|
if self.attn_last is not None:
|
|
|
|
|
|
|
|
out = self.attn_gain * self.attn_last(out)
|
|
|
|
out = self.drop_path(out)
|
|
|
|
out = self.drop_path(out)
|
|
|
|
if self.skipinit_gain is None:
|
|
|
|
|
|
|
|
out = out * self.alpha + shortcut
|
|
|
|
if self.skipinit_gain is not None:
|
|
|
|
else:
|
|
|
|
|
|
|
|
# this really slows things down for some reason, TBD
|
|
|
|
# this really slows things down for some reason, TBD
|
|
|
|
out = out * self.alpha * self.skipinit_gain + shortcut
|
|
|
|
out = out * self.skipinit_gain
|
|
|
|
|
|
|
|
out = out * self.alpha + shortcut
|
|
|
|
return out
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
|
|
|
|
def stem_info(stem_type):
|
|
|
|
stem_stride = 2
|
|
|
|
stem_stride = 2
|
|
|
|
|
|
|
|
if 'nff' in stem_type or 'pool' in stem_type:
|
|
|
|
|
|
|
|
stem_stride = 4
|
|
|
|
|
|
|
|
stem_feat = ''
|
|
|
|
|
|
|
|
if 'nff' in stem_type:
|
|
|
|
|
|
|
|
stem_feat = 'stem.act3'
|
|
|
|
|
|
|
|
elif 'deep' in stem_type and not 'pool' in stem_type:
|
|
|
|
|
|
|
|
stem_feat = 'stem.act2'
|
|
|
|
|
|
|
|
return stem_stride, stem_feat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
|
|
|
|
|
|
|
|
stem_stride = 2
|
|
|
|
|
|
|
|
stem_feature = ''
|
|
|
|
stem = OrderedDict()
|
|
|
|
stem = OrderedDict()
|
|
|
|
assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
|
|
|
|
assert stem_type in ('', 'nff', 'deep', 'deep_tiered', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
|
|
|
|
if 'deep' in stem_type:
|
|
|
|
if 'deep' in stem_type or 'nff' in stem_type:
|
|
|
|
# 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
|
|
|
|
# 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
|
|
|
|
mid_chs = out_chs // 2
|
|
|
|
if 'nff' in stem_type:
|
|
|
|
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
|
|
|
assert not 'pool' in stem_type
|
|
|
|
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
|
|
|
stem_chs = (16, 32, 64, out_chs)
|
|
|
|
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
|
|
|
strides = (2, 1, 1, 2)
|
|
|
|
|
|
|
|
stem_stride = 4
|
|
|
|
|
|
|
|
stem_feature = 'stem.act4'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
|
|
|
|
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
stem_chs = (out_chs // 2, out_chs // 2, out_chs)
|
|
|
|
|
|
|
|
strides = (2, 1, 1)
|
|
|
|
|
|
|
|
stem_feature = 'stem.act3'
|
|
|
|
|
|
|
|
last_idx = len(stem_chs) - 1
|
|
|
|
|
|
|
|
for i, (c, s) in enumerate(zip(stem_chs, strides)):
|
|
|
|
|
|
|
|
stem[f'conv{i+1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
|
|
|
|
|
|
|
|
if i != last_idx:
|
|
|
|
|
|
|
|
stem[f'act{i+2}'] = act_layer(inplace=True)
|
|
|
|
|
|
|
|
in_chs = c
|
|
|
|
elif '3x3' in stem_type:
|
|
|
|
elif '3x3' in stem_type:
|
|
|
|
# 3x3 stem conv as in RegNet
|
|
|
|
# 3x3 stem conv as in RegNet
|
|
|
|
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
|
|
|
|
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
|
|
|
@ -253,18 +364,30 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
|
|
|
|
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
|
|
|
|
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
|
|
|
|
stem_stride = 4
|
|
|
|
stem_stride = 4
|
|
|
|
|
|
|
|
|
|
|
|
return nn.Sequential(stem), stem_stride
|
|
|
|
return nn.Sequential(stem), stem_stride, stem_feature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_nonlin_gamma = dict(
|
|
|
|
_nonlin_gamma = dict(
|
|
|
|
silu=1./.5595,
|
|
|
|
identity=1.0,
|
|
|
|
relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
|
|
|
|
celu=1.270926833152771,
|
|
|
|
identity=1.0
|
|
|
|
elu=1.2716004848480225,
|
|
|
|
|
|
|
|
gelu=1.7015043497085571,
|
|
|
|
|
|
|
|
leaky_relu=1.70590341091156,
|
|
|
|
|
|
|
|
log_sigmoid=1.9193484783172607,
|
|
|
|
|
|
|
|
log_softmax=1.0002083778381348,
|
|
|
|
|
|
|
|
relu=1.7139588594436646,
|
|
|
|
|
|
|
|
relu6=1.7131484746932983,
|
|
|
|
|
|
|
|
selu=1.0008515119552612,
|
|
|
|
|
|
|
|
sigmoid=4.803835391998291,
|
|
|
|
|
|
|
|
silu=1.7881293296813965,
|
|
|
|
|
|
|
|
softsign=2.338853120803833,
|
|
|
|
|
|
|
|
softplus=1.9203323125839233,
|
|
|
|
|
|
|
|
tanh=1.5939117670059204,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NormalizerFreeNet(nn.Module):
|
|
|
|
class NormFreeNet(nn.Module):
|
|
|
|
""" Normalizer-free ResNets and RegNets
|
|
|
|
""" Normalization-free ResNets and RegNets
|
|
|
|
|
|
|
|
|
|
|
|
As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
|
|
|
|
As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
|
|
|
|
- https://arxiv.org/abs/2101.08692
|
|
|
|
- https://arxiv.org/abs/2101.08692
|
|
|
@ -298,10 +421,11 @@ class NormalizerFreeNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
stem_chs = cfg.stem_chs or cfg.channels[0]
|
|
|
|
stem_chs = cfg.stem_chs or cfg.channels[0]
|
|
|
|
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
|
|
|
|
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
|
|
|
|
self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
|
|
|
|
self.stem, stem_stride, stem_feat = create_stem(
|
|
|
|
|
|
|
|
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
|
|
self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4
|
|
|
|
self.feature_info = [dict(num_chs=stem_chs, reduction=2, module=stem_feat)] if stem_stride == 4 else []
|
|
|
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
|
|
|
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
|
|
|
prev_chs = stem_chs
|
|
|
|
prev_chs = stem_chs
|
|
|
|
net_stride = stem_stride
|
|
|
|
net_stride = stem_stride
|
|
|
|
dilation = 1
|
|
|
|
dilation = 1
|
|
|
@ -309,8 +433,8 @@ class NormalizerFreeNet(nn.Module):
|
|
|
|
stages = []
|
|
|
|
stages = []
|
|
|
|
for stage_idx, stage_depth in enumerate(cfg.depths):
|
|
|
|
for stage_idx, stage_depth in enumerate(cfg.depths):
|
|
|
|
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
|
|
|
|
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
|
|
|
|
self.feature_info += [dict(
|
|
|
|
if stride == 2:
|
|
|
|
num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
|
|
|
|
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
|
|
|
|
if net_stride >= output_stride and stride > 1:
|
|
|
|
if net_stride >= output_stride and stride > 1:
|
|
|
|
dilation *= stride
|
|
|
|
dilation *= stride
|
|
|
|
stride = 1
|
|
|
|
stride = 1
|
|
|
@ -321,7 +445,7 @@ class NormalizerFreeNet(nn.Module):
|
|
|
|
for block_idx in range(cfg.depths[stage_idx]):
|
|
|
|
for block_idx in range(cfg.depths[stage_idx]):
|
|
|
|
first_block = block_idx == 0 and stage_idx == 0
|
|
|
|
first_block = block_idx == 0 and stage_idx == 0
|
|
|
|
out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
|
|
|
|
out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
|
|
|
|
blocks += [NormalizationFreeBlock(
|
|
|
|
blocks += [NormFreeBlock(
|
|
|
|
in_chs=prev_chs, out_chs=out_chs,
|
|
|
|
in_chs=prev_chs, out_chs=out_chs,
|
|
|
|
alpha=cfg.alpha,
|
|
|
|
alpha=cfg.alpha,
|
|
|
|
beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block
|
|
|
|
beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block
|
|
|
@ -329,15 +453,16 @@ class NormalizerFreeNet(nn.Module):
|
|
|
|
dilation=dilation,
|
|
|
|
dilation=dilation,
|
|
|
|
first_dilation=first_dilation,
|
|
|
|
first_dilation=first_dilation,
|
|
|
|
group_size=cfg.group_size,
|
|
|
|
group_size=cfg.group_size,
|
|
|
|
bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
|
|
|
|
bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
|
|
|
|
efficient=cfg.efficient,
|
|
|
|
|
|
|
|
ch_div=cfg.ch_div,
|
|
|
|
ch_div=cfg.ch_div,
|
|
|
|
|
|
|
|
reg=cfg.reg,
|
|
|
|
|
|
|
|
extra_conv=cfg.extra_conv,
|
|
|
|
|
|
|
|
skipinit=cfg.skipinit,
|
|
|
|
attn_layer=attn_layer,
|
|
|
|
attn_layer=attn_layer,
|
|
|
|
attn_gain=cfg.attn_gain,
|
|
|
|
attn_gain=cfg.attn_gain,
|
|
|
|
act_layer=act_layer,
|
|
|
|
act_layer=act_layer,
|
|
|
|
conv_layer=conv_layer,
|
|
|
|
conv_layer=conv_layer,
|
|
|
|
drop_path_rate=dpr[stage_idx][block_idx],
|
|
|
|
drop_path_rate=drop_path_rates[stage_idx][block_idx],
|
|
|
|
skipinit=cfg.skipinit,
|
|
|
|
|
|
|
|
)]
|
|
|
|
)]
|
|
|
|
if block_idx == 0:
|
|
|
|
if block_idx == 0:
|
|
|
|
expected_var = 1. # expected var is reset after first block of each stage
|
|
|
|
expected_var = 1. # expected var is reset after first block of each stage
|
|
|
@ -347,22 +472,25 @@ class NormalizerFreeNet(nn.Module):
|
|
|
|
stages += [nn.Sequential(*blocks)]
|
|
|
|
stages += [nn.Sequential(*blocks)]
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
|
|
|
|
|
|
|
if cfg.efficient and cfg.num_features:
|
|
|
|
if cfg.num_features:
|
|
|
|
# The paper NFRegNet models have an EfficientNet-like final head convolution.
|
|
|
|
# The paper NFRegNet models have an EfficientNet-like final head convolution.
|
|
|
|
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
|
|
|
|
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
|
|
|
|
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
|
|
|
|
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
|
|
|
|
|
|
|
|
# FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.num_features = prev_chs
|
|
|
|
self.num_features = prev_chs
|
|
|
|
self.final_conv = nn.Identity()
|
|
|
|
self.final_conv = nn.Identity()
|
|
|
|
# FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
|
|
|
|
self.final_act = act_layer(inplace=cfg.num_features > 0)
|
|
|
|
self.final_act = act_layer()
|
|
|
|
|
|
|
|
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
|
|
|
|
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
|
|
|
|
|
|
|
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
if 'fc' in n and isinstance(m, nn.Linear):
|
|
|
|
if 'fc' in n and isinstance(m, nn.Linear):
|
|
|
|
|
|
|
|
if cfg.zero_init_fc:
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
nn.init.normal_(m.weight, 0., .01)
|
|
|
|
if m.bias is not None:
|
|
|
|
if m.bias is not None:
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
elif isinstance(m, nn.Conv2d):
|
|
|
|
elif isinstance(m, nn.Conv2d):
|
|
|
@ -395,17 +523,121 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
|
|
|
|
model_cfg = model_cfgs[variant]
|
|
|
|
model_cfg = model_cfgs[variant]
|
|
|
|
feature_cfg = dict(flatten_sequential=True)
|
|
|
|
feature_cfg = dict(flatten_sequential=True)
|
|
|
|
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
|
|
|
|
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
|
|
|
|
if 'pool' in model_cfg.stem_type:
|
|
|
|
if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
|
|
|
|
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
|
|
|
|
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems
|
|
|
|
|
|
|
|
|
|
|
|
return build_model_with_cfg(
|
|
|
|
return build_model_with_cfg(
|
|
|
|
NormalizerFreeNet, variant, pretrained,
|
|
|
|
NormFreeNet, variant, pretrained,
|
|
|
|
default_cfg=default_cfgs[variant],
|
|
|
|
default_cfg=default_cfgs[variant],
|
|
|
|
model_cfg=model_cfg,
|
|
|
|
model_cfg=model_cfg,
|
|
|
|
feature_cfg=feature_cfg,
|
|
|
|
feature_cfg=feature_cfg,
|
|
|
|
**kwargs)
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f0(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f1(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f2(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f3(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f4(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f5(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f6(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f7(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f0s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f1s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f2s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f3s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f4s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f5s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nfnet_f6s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nfnet_f7s(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b0(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b1(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b2(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b3(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b4(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def nf_regnet_b5(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def nf_regnet_b0(pretrained=False, **kwargs):
|
|
|
|
def nf_regnet_b0(pretrained=False, **kwargs):
|
|
|
|
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
|
|
|
|
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
|
|
|
@ -475,6 +707,7 @@ def nf_ecaresnet26(pretrained=False, **kwargs):
|
|
|
|
def nf_ecaresnet50(pretrained=False, **kwargs):
|
|
|
|
def nf_ecaresnet50(pretrained=False, **kwargs):
|
|
|
|
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
|
|
|
|
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def nf_ecaresnet101(pretrained=False, **kwargs):
|
|
|
|
def nf_ecaresnet101(pretrained=False, **kwargs):
|
|
|
|
return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
|
|
|
|
return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
|