diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index fe90767c..86ca9f7a 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module): and object detection models. """ - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): @@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module): channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) - self.feature_info = builder.features # builder provides info about feature channels for each block + self._feature_info = builder.features # builder provides info about feature channels for each block + self._stage_to_feature_idx = { + v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) if _DEBUG: - for k, v in self.feature_info.items(): + for k, v in self._feature_info.items(): print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper - hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' - hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] - self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = [dict( + name=self._feature_info[idx]['module'], + type=self._feature_info[idx]['hook_type']) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def feature_channels(self, idx=None): """ Feature Channel Shortcut @@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module): return feature channel count for that feature block index (independent of out_indices setting). """ if isinstance(idx, int): - return self.feature_info[idx]['num_chs'] - return [self.feature_info[i]['num_chs'] for i in self.out_indices] + return self._feature_info[idx]['num_chs'] + return [self._feature_info[i]['num_chs'] for i in self.out_indices] def forward(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) - self.blocks(x) - return self.feature_hooks.get_output(x.device) + if self.feature_hooks is None: + features = [] + for i, b in enumerate(self.blocks): + x = b(x) + if i in self._stage_to_feature_idx: + features.append(x) + return features + else: + self.blocks(x) + return self.feature_hooks.get_output(x.device) def _create_model(model_kwargs, default_cfg, pretrained=False):