Cleanup EfficientNet/MobileNetV3 feature extraction a bit, only two tap locations now, small mobilenetv3 models work

pull/175/head
Ross Wightman 4 years ago
parent 68fd8a267b
commit c146b54abc

@ -416,12 +416,6 @@ class EfficientNetFeatures(nn.Module):
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
# TODO only create stages needed, currently all stages are created regardless of out_indices
num_stages = max(out_indices) + 1
self.out_indices = out_indices
self.feature_location = feature_location
self.drop_rate = drop_rate
self._in_chs = in_chans
@ -439,14 +433,10 @@ class EfficientNetFeatures(nn.Module):
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
@ -460,14 +450,17 @@ class EfficientNetFeatures(nn.Module):
x = self.act1(x)
if self.feature_hooks is None:
features = []
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks):
x = b(x)
if i in self._stage_to_feature_idx:
if i + 1 in self._stage_out_idx:
features.append(x)
return features
else:
self.blocks(x)
return self.feature_hooks.get_output(x.device)
out = self.feature_hooks.get_output(x.device)
return list(out.values())
def _create_effnet(model_kwargs, variant, pretrained=False):

@ -128,10 +128,9 @@ class ConvBnAct(nn.Module):
self.act1 = act_layer(inplace=True)
def feature_info(self, location):
if location == 'expansion' or location == 'depthwise':
# no expansion or depthwise this block, use act after conv
if location == 'expansion': # output of conv after act, same as block coutput
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck'
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info
@ -175,12 +174,9 @@ class DepthwiseSeparableConv(nn.Module):
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
def feature_info(self, location):
if location == 'expansion':
# no expansion in this block, use depthwise, before SE
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
elif location == 'depthwise': # after SE
if location == 'expansion': # after SE, input to PW
info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck'
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return info
@ -245,11 +241,9 @@ class InvertedResidual(nn.Module):
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location):
if location == 'expansion':
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
elif location == 'depthwise': # after SE
if location == 'expansion': # after SE, input to PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck'
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info
@ -370,12 +364,9 @@ class EdgeResidual(nn.Module):
self.bn2 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location):
if location == 'expansion':
info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
elif location == 'depthwise':
# there is no depthwise, take after SE, before PWL
if location == 'expansion': # after SE, before PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck'
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info

@ -1,7 +1,6 @@
import logging
import math
import re
from collections import OrderedDict
from copy import deepcopy
import torch.nn as nn
@ -12,6 +11,11 @@ from .layers import CondConv2d, get_condconv_initializer
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
def _log_info_if(msg, condition):
if condition:
logging.info(msg)
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
@ -219,8 +223,12 @@ class EfficientNetBuilder:
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
logging.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
feature_location = 'expansion'
self.feature_location = feature_location
assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
assert feature_location in ('bottleneck', 'expansion', '')
self.verbose = verbose
# state updated during build, consumed by model
@ -247,8 +255,7 @@ class EfficientNetBuilder:
if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
@ -256,18 +263,15 @@ class EfficientNetBuilder:
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
_log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = EdgeResidual(**ba)
elif bt == 'cn':
if self.verbose:
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
_log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = ConvBnAct(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
@ -279,64 +283,55 @@ class EfficientNetBuilder:
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
model_block_args: A list of lists, outer list defines stages, inner
model_block_args: A list of lists, outer list defines stacks (block stages), inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
if self.verbose:
logging.info('Building model trunk with %d stages...' % len(model_block_args))
_log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0
current_stride = 2
current_dilation = 1
stages = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stage_idx, stage_block_args in enumerate(model_block_args):
last_stack = stage_idx == (len(model_block_args) - 1)
if self.verbose:
logging.info('Stack: {}'.format(stage_idx))
assert isinstance(stage_block_args, list)
if model_block_args[0][0]['stride'] > 1:
# if the first block starts with a stride, we need to extract first level feat from stem
feature_info = dict(
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
hook_type='forward' if self.feature_location != 'bottleneck' else '')
self.features.append(feature_info)
# outer list of block_args defines the stacks
for stack_idx, stack_args in enumerate(model_block_args):
last_stack = stack_idx + 1 == len(model_block_args)
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
assert isinstance(stack_args, list)
blocks = []
# each stack (stage) contains a list of block arguments
for block_idx, block_args in enumerate(stage_block_args):
last_block = block_idx == (len(stage_block_args) - 1)
extract_features = '' # No features extracted
if self.verbose:
logging.info(' Block: {}'.format(block_idx))
# Sort out stride, dilation, and feature extraction details
# each stack (stage of blocks) contains a list of block arguments
for block_idx, block_args in enumerate(stack_args):
last_block = block_idx + 1 == len(stack_args)
_log_info_if(' Block: {}'.format(block_idx), self.verbose)
assert block_args['stride'] in (1, 2)
if block_idx >= 1:
# only the first block in any stack can have a stride > 1
if block_idx >= 1: # only the first block in any stack can have a stride > 1
block_args['stride'] = 1
do_extract = False
if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
extract_features = False
if last_block:
next_stage_idx = stage_idx + 1
if next_stage_idx >= len(model_block_args):
do_extract = True
else:
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
elif self.feature_location == 'expansion':
if block_args['stride'] > 1 or (last_stack and last_block):
do_extract = True
if do_extract:
extract_features = self.feature_location
next_stack_idx = stack_idx + 1
extract_features = next_stack_idx >= len(model_block_args) or \
model_block_args[next_stack_idx][0]['stride'] > 1
next_dilation = current_dilation
next_output_stride = current_stride
if block_args['stride'] > 1:
next_output_stride = current_stride * block_args['stride']
if next_output_stride > self.output_stride:
next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1
if self.verbose:
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format(
self.output_stride))
_log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
self.output_stride), self.verbose)
else:
current_stride = next_output_stride
block_args['dilation'] = current_dilation
@ -349,15 +344,11 @@ class EfficientNetBuilder:
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = block.feature_info(extract_features)
module_name = f'blocks.{stage_idx}.{block_idx}'
if 'module' in feature_info and feature_info['module']:
feature_info['module'] = '.'.join([module_name, feature_info['module']])
else:
feature_info['module'] = module_name
feature_info['stage_idx'] = stage_idx
feature_info['block_idx'] = block_idx
feature_info['reduction'] = current_stride
feature_info = dict(
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
module_name = f'blocks.{stack_idx}.{block_idx}'
leaf_name = feature_info.get('module', '')
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
self.features.append(feature_info)
total_block_idx += 1 # incr global block idx (across all stacks)

@ -162,11 +162,6 @@ class MobileNetV3Features(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
# TODO only create stages needed, currently all stages are created regardless of out_indices
num_stages = max(out_indices) + 1
self.out_indices = out_indices
self.drop_rate = drop_rate
self._in_chs = in_chans
@ -183,14 +178,10 @@ class MobileNetV3Features(nn.Module):
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
@ -204,9 +195,11 @@ class MobileNetV3Features(nn.Module):
x = self.act1(x)
if self.feature_hooks is None:
features = []
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks):
x = b(x)
if i in self._stage_to_feature_idx:
if i + 1 in self._stage_out_idx:
features.append(x)
return features
else:

Loading…
Cancel
Save