From bdb165a8a43cecdd77a0fa1d092b288e2d0023fb Mon Sep 17 00:00:00 2001 From: Alexey Chernov <4ernov@gmail.com> Date: Mon, 13 Apr 2020 02:02:14 +0300 Subject: [PATCH] Merge changes in feature extraction interface to MobileNetV3 Experimental feature extraction interface seems to be changed a little bit with the most up to date version apparently found in EfficientNet class. Here these changes are added to MobileNetV3 class to make it support it and work again, too. --- timm/models/mobilenetv3.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index fe90767c..86ca9f7a 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module): and object detection models. """ - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): @@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module): channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) - self.feature_info = builder.features # builder provides info about feature channels for each block + self._feature_info = builder.features # builder provides info about feature channels for each block + self._stage_to_feature_idx = { + v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) if _DEBUG: - for k, v in self.feature_info.items(): + for k, v in self._feature_info.items(): print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper - hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' - hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] - self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = [dict( + name=self._feature_info[idx]['module'], + type=self._feature_info[idx]['hook_type']) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def feature_channels(self, idx=None): """ Feature Channel Shortcut @@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module): return feature channel count for that feature block index (independent of out_indices setting). """ if isinstance(idx, int): - return self.feature_info[idx]['num_chs'] - return [self.feature_info[i]['num_chs'] for i in self.out_indices] + return self._feature_info[idx]['num_chs'] + return [self._feature_info[i]['num_chs'] for i in self.out_indices] def forward(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) - self.blocks(x) - return self.feature_hooks.get_output(x.device) + if self.feature_hooks is None: + features = [] + for i, b in enumerate(self.blocks): + x = b(x) + if i in self._stage_to_feature_idx: + features.append(x) + return features + else: + self.blocks(x) + return self.feature_hooks.get_output(x.device) def _create_model(model_kwargs, default_cfg, pretrained=False):