Fix efficient head for MobileNetV3

pull/1/head
Ross Wightman 6 years ago
parent 17da1adaca
commit c3fbdd4655

@ -290,6 +290,7 @@ class _BlockBuilder:
ba['bn_eps'] = self.bn_eps ba['bn_eps'] = self.bn_eps
ba['folded_bn'] = self.folded_bn ba['folded_bn'] = self.folded_bn
ba['padding_same'] = self.padding_same ba['padding_same'] = self.padding_same
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None assert ba['act_fn'] is not None
if _DEBUG: if _DEBUG:
@ -611,15 +612,14 @@ class GenMobileNet(nn.Module):
depth_multiplier=1.0, depth_divisor=8, min_depth=None, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False, drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
global_pool='avg', skip_head_conv=False, efficient_head=False, global_pool='avg', head_conv='default', weight_init='goog',
weight_init='goog', folded_bn=False, padding_same=False): folded_bn=False, padding_same=False):
super(GenMobileNet, self).__init__() super(GenMobileNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.depth_multiplier = depth_multiplier self.depth_multiplier = depth_multiplier
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.act_fn = act_fn self.act_fn = act_fn
self.num_features = num_features self.num_features = num_features
self.efficient_head = efficient_head # pool before last conv
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth) stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
self.conv_stem = sconv2d( self.conv_stem = sconv2d(
@ -629,19 +629,22 @@ class GenMobileNet(nn.Module):
in_chs = stem_size in_chs = stem_size
builder = _BlockBuilder( builder = _BlockBuilder(
depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid, depth_multiplier, depth_divisor, min_depth,
act_fn, se_gate_fn, se_reduce_mid,
bn_momentum, bn_eps, folded_bn, padding_same) bn_momentum, bn_eps, folded_bn, padding_same)
self.blocks = nn.Sequential(*builder(in_chs, block_args)) self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs in_chs = builder.in_chs
if skip_head_conv: if not head_conv or head_conv == 'none':
self.efficient_head = False
self.conv_head = None self.conv_head = None
assert in_chs == self.num_features assert in_chs == self.num_features
else: else:
self.efficient_head = head_conv == 'efficient'
self.conv_head = sconv2d( self.conv_head = sconv2d(
in_chs, self.num_features, 1, in_chs, self.num_features, 1,
padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head) padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
self.bn2 = None if (folded_bn or efficient_head) else \ self.bn2 = None if (folded_bn or self.efficient_head) else \
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps) nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -674,7 +677,7 @@ class GenMobileNet(nn.Module):
x = self.blocks(x) x = self.blocks(x)
if self.efficient_head: if self.efficient_head:
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
x = self.global_pool(x) # always need to pool here regardless of bool x = self.global_pool(x) # always need to pool here regardless of flag
x = self.conv_head(x) x = self.conv_head(x)
# no BN # no BN
x = self.act_fn(x) x = self.act_fn(x)
@ -836,7 +839,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
bn_momentum=bn_momentum, bn_momentum=bn_momentum,
bn_eps=bn_eps, bn_eps=bn_eps,
act_fn=F.relu6, act_fn=F.relu6,
skip_head_conv=True, head_conv='none',
**kwargs **kwargs
) )
return model return model
@ -914,6 +917,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
act_fn=hard_swish, act_fn=hard_swish,
se_gate_fn=hard_sigmoid, se_gate_fn=hard_sigmoid,
se_reduce_mid=True, se_reduce_mid=True,
head_conv='efficient',
**kwargs **kwargs
) )
return model return model

Loading…
Cancel
Save