From 17da1adaca899d9264a7034742f06e03846d205b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 11 May 2019 10:23:40 -0700 Subject: [PATCH] A few MobileNetV3 tweaks * fix expansion ratio on early block * change comment re stride mistake in paper * fix rounding not being called properly for all multipliers != 1.0 --- models/genmobilenet.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/models/genmobilenet.py b/models/genmobilenet.py index 914e82cc..472d43e4 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -285,7 +285,7 @@ class _BlockBuilder: def _make_block(self, ba): bt = ba.pop('block_type') ba['in_chs'] = self.in_chs - ba['out_chs'] = _round_channels(ba['out_chs']) + ba['out_chs'] = self._round_channels(ba['out_chs']) ba['bn_momentum'] = self.bn_momentum ba['bn_eps'] = self.bn_eps ba['folded_bn'] = self.folded_bn @@ -676,6 +676,7 @@ class GenMobileNet(nn.Module): # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv x = self.global_pool(x) # always need to pool here regardless of bool x = self.conv_head(x) + # no BN x = self.act_fn(x) if pool: # expect flattened output if pool is true, otherwise keep dim @@ -884,7 +885,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs): # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_are_noskip'], # relu # stage 1, 112x112 in - ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e6_c24_are'], # relu + ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'], # relu # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu # stage 3, 28x28 in @@ -893,9 +894,10 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs): # stage 4, 14x14in ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish # stage 5, 14x14in - # FIXME the paper contains a weird block-stride pattern 1-2-1 that doesn't fit the usual 2-1-... - # What is correct? 
- ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # FIXME the paper lists a block-stride pattern of 1-2-1 here, which looks like a mistake as it doesn't fit the usual 2-1-... pattern; ignoring it + # The paper's numbers result in an expansion factor of 4.2 in the middle of this block, but keeping it at 6 + # results in a param count closer to 5.4M + ['ir_r1_k5_s2_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25'], # hard-swish # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ]