A few MobileNetV3 tweaks

* fix the expansion ratio on an early block (second block of stage 1: e6 -> e3)
* update the comment regarding the block-stride mistake in the paper
* fix channel rounding not being called properly for any multiplier != 1.0
pull/1/head
Ross Wightman 6 years ago
parent 6523e4abe4
commit 17da1adaca

@@ -285,7 +285,7 @@ class _BlockBuilder:
def _make_block(self, ba): def _make_block(self, ba):
bt = ba.pop('block_type') bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs ba['in_chs'] = self.in_chs
ba['out_chs'] = _round_channels(ba['out_chs']) ba['out_chs'] = self._round_channels(ba['out_chs'])
ba['bn_momentum'] = self.bn_momentum ba['bn_momentum'] = self.bn_momentum
ba['bn_eps'] = self.bn_eps ba['bn_eps'] = self.bn_eps
ba['folded_bn'] = self.folded_bn ba['folded_bn'] = self.folded_bn
@@ -676,6 +676,7 @@ class GenMobileNet(nn.Module):
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
x = self.global_pool(x) # always need to pool here regardless of bool x = self.global_pool(x) # always need to pool here regardless of bool
x = self.conv_head(x) x = self.conv_head(x)
# no BN
x = self.act_fn(x) x = self.act_fn(x)
if pool: if pool:
# expect flattened output if pool is true, otherwise keep dim # expect flattened output if pool is true, otherwise keep dim
@@ -884,7 +885,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
# stage 0, 112x112 in # stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_are_noskip'], # relu ['ds_r1_k3_s1_e1_c16_are_noskip'], # relu
# stage 1, 112x112 in # stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e6_c24_are'], # relu ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'], # relu
# stage 2, 56x56 in # stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu ['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu
# stage 3, 28x28 in # stage 3, 28x28 in
@@ -893,9 +894,10 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
# stage 4, 14x14in # stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in # stage 5, 14x14in
# FIXME the paper contains a weird block-stride pattern 1-2-1 that doesn't fit the usual 2-1-... # FIXME paper has a mistaken block-stride pattern 1-2-1 that doesn't fit the usual 2-1-..., ignoring
# What is correct? # The paper numbers result in an exp factor of 4.2 in the middle of this block, but keeping at 6
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish # results in a param count closer to 5.4m
['ir_r1_k5_s2_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in # stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish ['cn_r1_k1_s1_c960'], # hard-swish
] ]

Loading…
Cancel
Save