|
|
@ -290,6 +290,7 @@ class _BlockBuilder:
|
|
|
|
ba['bn_eps'] = self.bn_eps
|
|
|
|
ba['bn_eps'] = self.bn_eps
|
|
|
|
ba['folded_bn'] = self.folded_bn
|
|
|
|
ba['folded_bn'] = self.folded_bn
|
|
|
|
ba['padding_same'] = self.padding_same
|
|
|
|
ba['padding_same'] = self.padding_same
|
|
|
|
|
|
|
|
# block act fn overrides the model default
|
|
|
|
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
|
|
|
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
|
|
|
assert ba['act_fn'] is not None
|
|
|
|
assert ba['act_fn'] is not None
|
|
|
|
if _DEBUG:
|
|
|
|
if _DEBUG:
|
|
|
@ -611,15 +612,14 @@ class GenMobileNet(nn.Module):
|
|
|
|
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
|
|
|
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
|
|
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
|
|
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
|
|
|
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
|
|
|
|
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
|
|
|
|
global_pool='avg', skip_head_conv=False, efficient_head=False,
|
|
|
|
global_pool='avg', head_conv='default', weight_init='goog',
|
|
|
|
weight_init='goog', folded_bn=False, padding_same=False):
|
|
|
|
folded_bn=False, padding_same=False):
|
|
|
|
super(GenMobileNet, self).__init__()
|
|
|
|
super(GenMobileNet, self).__init__()
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.depth_multiplier = depth_multiplier
|
|
|
|
self.depth_multiplier = depth_multiplier
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.act_fn = act_fn
|
|
|
|
self.act_fn = act_fn
|
|
|
|
self.num_features = num_features
|
|
|
|
self.num_features = num_features
|
|
|
|
self.efficient_head = efficient_head # pool before last conv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
|
|
|
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
|
|
|
self.conv_stem = sconv2d(
|
|
|
|
self.conv_stem = sconv2d(
|
|
|
@ -629,19 +629,22 @@ class GenMobileNet(nn.Module):
|
|
|
|
in_chs = stem_size
|
|
|
|
in_chs = stem_size
|
|
|
|
|
|
|
|
|
|
|
|
builder = _BlockBuilder(
|
|
|
|
builder = _BlockBuilder(
|
|
|
|
depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid,
|
|
|
|
depth_multiplier, depth_divisor, min_depth,
|
|
|
|
|
|
|
|
act_fn, se_gate_fn, se_reduce_mid,
|
|
|
|
bn_momentum, bn_eps, folded_bn, padding_same)
|
|
|
|
bn_momentum, bn_eps, folded_bn, padding_same)
|
|
|
|
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
|
|
|
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
|
|
|
in_chs = builder.in_chs
|
|
|
|
in_chs = builder.in_chs
|
|
|
|
|
|
|
|
|
|
|
|
if skip_head_conv:
|
|
|
|
if not head_conv or head_conv == 'none':
|
|
|
|
|
|
|
|
self.efficient_head = False
|
|
|
|
self.conv_head = None
|
|
|
|
self.conv_head = None
|
|
|
|
assert in_chs == self.num_features
|
|
|
|
assert in_chs == self.num_features
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.efficient_head = head_conv == 'efficient'
|
|
|
|
self.conv_head = sconv2d(
|
|
|
|
self.conv_head = sconv2d(
|
|
|
|
in_chs, self.num_features, 1,
|
|
|
|
in_chs, self.num_features, 1,
|
|
|
|
padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head)
|
|
|
|
padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
|
|
|
|
self.bn2 = None if (folded_bn or efficient_head) else \
|
|
|
|
self.bn2 = None if (folded_bn or self.efficient_head) else \
|
|
|
|
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
|
|
|
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
|
|
|
|
|
|
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
@ -674,7 +677,7 @@ class GenMobileNet(nn.Module):
|
|
|
|
x = self.blocks(x)
|
|
|
|
x = self.blocks(x)
|
|
|
|
if self.efficient_head:
|
|
|
|
if self.efficient_head:
|
|
|
|
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
|
|
|
|
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
|
|
|
|
x = self.global_pool(x) # always need to pool here regardless of bool
|
|
|
|
x = self.global_pool(x) # always need to pool here regardless of flag
|
|
|
|
x = self.conv_head(x)
|
|
|
|
x = self.conv_head(x)
|
|
|
|
# no BN
|
|
|
|
# no BN
|
|
|
|
x = self.act_fn(x)
|
|
|
|
x = self.act_fn(x)
|
|
|
@ -836,7 +839,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
|
|
|
|
bn_momentum=bn_momentum,
|
|
|
|
bn_momentum=bn_momentum,
|
|
|
|
bn_eps=bn_eps,
|
|
|
|
bn_eps=bn_eps,
|
|
|
|
act_fn=F.relu6,
|
|
|
|
act_fn=F.relu6,
|
|
|
|
skip_head_conv=True,
|
|
|
|
head_conv='none',
|
|
|
|
**kwargs
|
|
|
|
**kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
@ -914,6 +917,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
|
|
|
|
act_fn=hard_swish,
|
|
|
|
act_fn=hard_swish,
|
|
|
|
se_gate_fn=hard_sigmoid,
|
|
|
|
se_gate_fn=hard_sigmoid,
|
|
|
|
se_reduce_mid=True,
|
|
|
|
se_reduce_mid=True,
|
|
|
|
|
|
|
|
head_conv='efficient',
|
|
|
|
**kwargs
|
|
|
|
**kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|