Merge changes in feature extraction interface to MobileNetV3

Experimental feature extraction interface seems to be changed
a little bit with the most up to date version apparently found
in EfficientNet class. Here these changes are added to
MobileNetV3 class to make it support it and work again, too.
pull/123/head
Alexey Chernov 4 years ago
parent 13cf68850b
commit bdb165a8a4

@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
and object detection models. and object detection models.
""" """
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None): norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features # builder provides info about feature channels for each block self._feature_info = builder.features # builder provides info about feature channels for each block
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
self._in_chs = builder.in_chs self._in_chs = builder.in_chs
efficientnet_init_weights(self) efficientnet_init_weights(self)
if _DEBUG: if _DEBUG:
for k, v in self.feature_info.items(): for k, v in self._feature_info.items():
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' self.feature_hooks = None
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] if feature_location != 'bottleneck':
self.feature_hooks = FeatureHooks(hooks, self.named_modules()) hooks = [dict(
name=self._feature_info[idx]['module'],
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def feature_channels(self, idx=None): def feature_channels(self, idx=None):
""" Feature Channel Shortcut """ Feature Channel Shortcut
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
return feature channel count for that feature block index (independent of out_indices setting). return feature channel count for that feature block index (independent of out_indices setting).
""" """
if isinstance(idx, int): if isinstance(idx, int):
return self.feature_info[idx]['num_chs'] return self._feature_info[idx]['num_chs']
return [self.feature_info[i]['num_chs'] for i in self.out_indices] return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x): def forward(self, x):
x = self.conv_stem(x) x = self.conv_stem(x)
x = self.bn1(x) x = self.bn1(x)
x = self.act1(x) x = self.act1(x)
self.blocks(x) if self.feature_hooks is None:
return self.feature_hooks.get_output(x.device) features = []
for i, b in enumerate(self.blocks):
x = b(x)
if i in self._stage_to_feature_idx:
features.append(x)
return features
else:
self.blocks(x)
return self.feature_hooks.get_output(x.device)
def _create_model(model_kwargs, default_cfg, pretrained=False): def _create_model(model_kwargs, default_cfg, pretrained=False):

Loading…
Cancel
Save