Merge changes in feature extraction interface to MobileNetV3

Experimental feature extraction interface seems to be changed
a little bit with the most up to date version apparently found
in EfficientNet class. Here these changes are added to
MobileNetV3 class to make it support it and work again, too.
pull/123/head
Alexey Chernov 5 years ago
parent 13cf68850b
commit bdb165a8a4

@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features # builder provides info about feature channels for each block
self._feature_info = builder.features # builder provides info about feature channels for each block
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for k, v in self.feature_info.items():
for k, v in self._feature_info.items():
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = [dict(
name=self._feature_info[idx]['module'],
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def feature_channels(self, idx=None):
""" Feature Channel Shortcut
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self.feature_info[idx]['num_chs']
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
self.blocks(x)
return self.feature_hooks.get_output(x.device)
if self.feature_hooks is None:
features = []
for i, b in enumerate(self.blocks):
x = b(x)
if i in self._stage_to_feature_idx:
features.append(x)
return features
else:
self.blocks(x)
return self.feature_hooks.get_output(x.device)
def _create_model(model_kwargs, default_cfg, pretrained=False):

Loading…
Cancel
Save