Merge pull request #123 from aclex/mobilenetv3_fix_feature_extraction

Merge changes in feature extraction interface to MobileNetV3
pull/125/head
Ross Wightman 4 years ago committed by GitHub
commit e15f979457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features # builder provides info about feature channels for each block
self._feature_info = builder.features # builder provides info about feature channels for each block
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for k, v in self.feature_info.items():
for k, v in self._feature_info.items():
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = [dict(
name=self._feature_info[idx]['module'],
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def feature_channels(self, idx=None):
""" Feature Channel Shortcut
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self.feature_info[idx]['num_chs']
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
self.blocks(x)
return self.feature_hooks.get_output(x.device)
if self.feature_hooks is None:
features = []
for i, b in enumerate(self.blocks):
x = b(x)
if i in self._stage_to_feature_idx:
features.append(x)
return features
else:
self.blocks(x)
return self.feature_hooks.get_output(x.device)
def _create_model(model_kwargs, default_cfg, pretrained=False):

Loading…
Cancel
Save