From c3fbdd465516c1d4c4ac9e62a5bbbf6aa4c74e71 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 11 May 2019 11:23:11 -0700 Subject: [PATCH] Fix efficient head for MobileNetV3 --- models/genmobilenet.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/models/genmobilenet.py b/models/genmobilenet.py index 472d43e4..34caef14 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -290,6 +290,7 @@ class _BlockBuilder: ba['bn_eps'] = self.bn_eps ba['folded_bn'] = self.folded_bn ba['padding_same'] = self.padding_same + # block act fn overrides the model default ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn assert ba['act_fn'] is not None if _DEBUG: @@ -611,15 +612,14 @@ class GenMobileNet(nn.Module): depth_multiplier=1.0, depth_divisor=8, min_depth=None, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False, - global_pool='avg', skip_head_conv=False, efficient_head=False, - weight_init='goog', folded_bn=False, padding_same=False): + global_pool='avg', head_conv='default', weight_init='goog', + folded_bn=False, padding_same=False): super(GenMobileNet, self).__init__() self.num_classes = num_classes self.depth_multiplier = depth_multiplier self.drop_rate = drop_rate self.act_fn = act_fn self.num_features = num_features - self.efficient_head = efficient_head # pool before last conv stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth) self.conv_stem = sconv2d( @@ -629,19 +629,22 @@ class GenMobileNet(nn.Module): in_chs = stem_size builder = _BlockBuilder( - depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid, + depth_multiplier, depth_divisor, min_depth, + act_fn, se_gate_fn, se_reduce_mid, bn_momentum, bn_eps, folded_bn, padding_same) self.blocks = nn.Sequential(*builder(in_chs, block_args)) in_chs = builder.in_chs - if skip_head_conv: + if not head_conv or head_conv == 'none': + self.efficient_head = False self.conv_head = None assert in_chs == self.num_features else: + self.efficient_head = head_conv == 'efficient' self.conv_head = sconv2d( in_chs, self.num_features, 1, - padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head) - self.bn2 = None if (folded_bn or efficient_head) else \ + padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head) + self.bn2 = None if (folded_bn or self.efficient_head) else \ nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -674,7 +677,7 @@ class GenMobileNet(nn.Module): x = self.blocks(x) if self.efficient_head: # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv - x = self.global_pool(x) # always need to pool here regardless of bool + x = self.global_pool(x) # always need to pool here regardless of flag x = self.conv_head(x) # no BN x = self.act_fn(x) @@ -836,7 +839,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs): bn_momentum=bn_momentum, bn_eps=bn_eps, act_fn=F.relu6, - skip_head_conv=True, + head_conv='none', **kwargs ) return model @@ -914,6 +917,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs): act_fn=hard_swish, se_gate_fn=hard_sigmoid, se_reduce_mid=True, + head_conv='efficient', **kwargs ) return model