|
|
|
@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
|
|
|
|
|
and object detection models.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
|
|
|
|
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
|
|
|
|
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
|
|
|
|
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
|
|
|
|
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
|
|
|
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
|
|
|
|
|
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
|
|
|
|
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
|
|
|
|
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
|
|
|
|
self.feature_info = builder.features # builder provides info about feature channels for each block
|
|
|
|
|
self._feature_info = builder.features # builder provides info about feature channels for each block
|
|
|
|
|
self._stage_to_feature_idx = {
|
|
|
|
|
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
|
|
|
|
self._in_chs = builder.in_chs
|
|
|
|
|
|
|
|
|
|
efficientnet_init_weights(self)
|
|
|
|
|
if _DEBUG:
|
|
|
|
|
for k, v in self.feature_info.items():
|
|
|
|
|
for k, v in self._feature_info.items():
|
|
|
|
|
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
|
|
|
|
|
|
|
|
|
# Register feature extraction hooks with FeatureHooks helper
|
|
|
|
|
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
|
|
|
|
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
|
|
|
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
|
|
|
|
self.feature_hooks = None
|
|
|
|
|
if feature_location != 'bottleneck':
|
|
|
|
|
hooks = [dict(
|
|
|
|
|
name=self._feature_info[idx]['module'],
|
|
|
|
|
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
|
|
|
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
|
|
|
|
|
|
|
|
|
def feature_channels(self, idx=None):
|
|
|
|
|
""" Feature Channel Shortcut
|
|
|
|
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
|
|
|
|
|
return feature channel count for that feature block index (independent of out_indices setting).
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(idx, int):
|
|
|
|
|
return self.feature_info[idx]['num_chs']
|
|
|
|
|
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
|
|
|
|
return self._feature_info[idx]['num_chs']
|
|
|
|
|
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.conv_stem(x)
|
|
|
|
|
x = self.bn1(x)
|
|
|
|
|
x = self.act1(x)
|
|
|
|
|
self.blocks(x)
|
|
|
|
|
return self.feature_hooks.get_output(x.device)
|
|
|
|
|
if self.feature_hooks is None:
|
|
|
|
|
features = []
|
|
|
|
|
for i, b in enumerate(self.blocks):
|
|
|
|
|
x = b(x)
|
|
|
|
|
if i in self._stage_to_feature_idx:
|
|
|
|
|
features.append(x)
|
|
|
|
|
return features
|
|
|
|
|
else:
|
|
|
|
|
self.blocks(x)
|
|
|
|
|
return self.feature_hooks.get_output(x.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
|
|
|
|