Add Single-Path NAS pixel1 model

pull/1/head
Ross Wightman 6 years ago
parent 419555be62
commit 34cd76899f

@ -5,6 +5,7 @@ A generic MobileNet class with building blocks to support a variety of models:
* MobileNetV2 * MobileNetV2
* FBNet-C (TODO A & B) * FBNet-C (TODO A & B)
* ChamNet (TODO still guessing at architecture definition) * ChamNet (TODO still guessing at architecture definition)
* Single-Path NAS Pixel1
* ShuffleNetV2 (TODO add IR shuffle block) * ShuffleNetV2 (TODO add IR shuffle block)
* And likely more... * And likely more...
@ -25,8 +26,9 @@ from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40', __all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', 'mnasnet_small',
'mnasnet_small'] 'mobilenetv1_1_00', 'mobilenetv2_1_00', 'chamnetv1_1_00', 'chamnetv2_1_00',
'fbnetc_1_00', 'spnasnet1_00']
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -54,6 +56,7 @@ default_cfgs = {
'chamnetv1_1_00': _cfg(url=''), 'chamnetv1_1_00': _cfg(url=''),
'chamnetv2_1_00': _cfg(url=''), 'chamnetv2_1_00': _cfg(url=''),
'fbnetc_1_00': _cfg(url=''), 'fbnetc_1_00': _cfg(url=''),
'spnasnet1_00': _cfg(url=''),
} }
_DEBUG = True _DEBUG = True
@ -476,6 +479,7 @@ class GenMobileNet(nn.Module):
* MNASNet A1, B1, and small * MNASNet A1, B1, and small
* FBNet A, B, and C * FBNet A, B, and C
* ChamNet (arch details are murky) * ChamNet (arch details are murky)
* Single-Path NAS Pixel1
""" """
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
@ -820,6 +824,45 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
return model return model
def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
"""Creates the Single-Path NAS model from search targeted for Pixel1 phone.
Paper: https://arxiv.org/abs/1904.02877
Args:
depth_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_c16_noskip'],
# stage 1, 112x112 in
['ir_r3_k3_s2_e3_c24'],
# stage 2, 56x56 in
['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
# stage 4, 14x14in
['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
# stage 5, 14x14in
['ir_r4_k5_s2_e6_c192'],
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320_noskip']
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenMobileNet(
arch_def,
num_classes=num_classes,
stem_size=32,
depth_multiplier=depth_multiplier,
depth_divisor=8,
min_depth=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
**kwargs
)
return model
def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 0.5. """ """ MNASNet B1, depth multiplier of 0.5. """
default_cfg = default_cfgs['mnasnet0_50'] default_cfg = default_cfgs['mnasnet0_50']
@ -958,3 +1001,13 @@ def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def spnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
""" Single-Path NAS Pixel1"""
default_cfg = default_cfgs['spnasnet1_00']
model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -11,7 +11,8 @@ from models.pnasnet import pnasnet5large
from models.genmobilenet import \ from models.genmobilenet import \
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\ mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\ semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00 mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
spnasnet1_00
from models.helpers import load_checkpoint from models.helpers import load_checkpoint

Loading…
Cancel
Save