From 34cd76899f0590c2d8232fc6ace70694487ff08c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Apr 2019 12:43:45 -0700 Subject: [PATCH] Add Single-Path NAS pixel1 model --- models/genmobilenet.py | 57 +++++++++++++++++++++++++++++++++++++++-- models/model_factory.py | 3 ++- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/models/genmobilenet.py b/models/genmobilenet.py index a4b177ad..46647fd8 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -5,6 +5,7 @@ A generic MobileNet class with building blocks to support a variety of models: * MobileNetV2 * FBNet-C (TODO A & B) * ChamNet (TODO still guessing at architecture definition) +* Single-Path NAS Pixel1 * ShuffleNetV2 (TODO add IR shuffle block) * And likely more... @@ -25,8 +26,9 @@ from models.adaptive_avgmax_pool import SelectAdaptivePool2d from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40', - 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', - 'mnasnet_small'] + 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', 'mnasnet_small', + 'mobilenetv1_1_00', 'mobilenetv2_1_00', 'chamnetv1_1_00', 'chamnetv2_1_00', + 'fbnetc_1_00', 'spnasnet1_00'] def _cfg(url='', **kwargs): @@ -54,6 +56,7 @@ default_cfgs = { 'chamnetv1_1_00': _cfg(url=''), 'chamnetv2_1_00': _cfg(url=''), 'fbnetc_1_00': _cfg(url=''), + 'spnasnet1_00': _cfg(url=''), } _DEBUG = True @@ -476,6 +479,7 @@ class GenMobileNet(nn.Module): * MNASNet A1, B1, and small * FBNet A, B, and C * ChamNet (arch details are murky) + * Single-Path NAS Pixel1 """ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, @@ -820,6 +824,45 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs): return model +def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + depth_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=32, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + **kwargs + ) + return model + + def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ default_cfg = default_cfgs['mnasnet0_50'] @@ -958,3 +1001,13 @@ def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + + +def spnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + default_cfg = default_cfgs['spnasnet1_00'] + model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model diff --git a/models/model_factory.py b/models/model_factory.py index 99cfa931..64c578ac 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -11,7 +11,8 @@ from models.pnasnet import pnasnet5large from models.genmobilenet import \ mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\ semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\ - mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00 + mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\ + spnasnet1_00 from models.helpers import load_checkpoint