Cleanup EfficientNet/MobileNetV3 feature extraction a bit, only two tap locations now, small mobilenetv3 models work

pull/175/head
Ross Wightman 4 years ago
parent 68fd8a267b
commit c146b54abc

@ -416,12 +416,6 @@ class EfficientNetFeatures(nn.Module):
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(EfficientNetFeatures, self).__init__() super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {} norm_kwargs = norm_kwargs or {}
# TODO only create stages needed, currently all stages are created regardless of out_indices
num_stages = max(out_indices) + 1
self.out_indices = out_indices
self.feature_location = feature_location
self.drop_rate = drop_rate self.drop_rate = drop_rate
self._in_chs = in_chans self._in_chs = in_chans
@ -439,14 +433,10 @@ class EfficientNetFeatures(nn.Module):
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices) self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = { self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._in_chs = builder.in_chs self._in_chs = builder.in_chs
efficientnet_init_weights(self) efficientnet_init_weights(self)
if _DEBUG:
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
@ -460,14 +450,17 @@ class EfficientNetFeatures(nn.Module):
x = self.act1(x) x = self.act1(x)
if self.feature_hooks is None: if self.feature_hooks is None:
features = [] features = []
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks): for i, b in enumerate(self.blocks):
x = b(x) x = b(x)
if i in self._stage_to_feature_idx: if i + 1 in self._stage_out_idx:
features.append(x) features.append(x)
return features return features
else: else:
self.blocks(x) self.blocks(x)
return self.feature_hooks.get_output(x.device) out = self.feature_hooks.get_output(x.device)
return list(out.values())
def _create_effnet(model_kwargs, variant, pretrained=False): def _create_effnet(model_kwargs, variant, pretrained=False):

@ -128,10 +128,9 @@ class ConvBnAct(nn.Module):
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
def feature_info(self, location): def feature_info(self, location):
if location == 'expansion' or location == 'depthwise': if location == 'expansion': # output of conv after act, same as block coutput
# no expansion or depthwise this block, use act after conv
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels) info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck' else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels) info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info return info
@ -175,12 +174,9 @@ class DepthwiseSeparableConv(nn.Module):
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
def feature_info(self, location): def feature_info(self, location):
if location == 'expansion': if location == 'expansion': # after SE, input to PW
# no expansion in this block, use depthwise, before SE
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
elif location == 'depthwise': # after SE
info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck' else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return info return info
@ -245,11 +241,9 @@ class InvertedResidual(nn.Module):
self.bn3 = norm_layer(out_chs, **norm_kwargs) self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location): def feature_info(self, location):
if location == 'expansion': if location == 'expansion': # after SE, input to PWL
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
elif location == 'depthwise': # after SE
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck' else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info return info
@ -370,12 +364,9 @@ class EdgeResidual(nn.Module):
self.bn2 = norm_layer(out_chs, **norm_kwargs) self.bn2 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location): def feature_info(self, location):
if location == 'expansion': if location == 'expansion': # after SE, before PWL
info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
elif location == 'depthwise':
# there is no depthwise, take after SE, before PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck' else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info return info

@ -1,7 +1,6 @@
import logging import logging
import math import math
import re import re
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
import torch.nn as nn import torch.nn as nn
@ -12,6 +11,11 @@ from .layers import CondConv2d, get_condconv_initializer
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"] __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
def _log_info_if(msg, condition):
if condition:
logging.info(msg)
def _parse_ksize(ss): def _parse_ksize(ss):
if ss.isdigit(): if ss.isdigit():
return int(ss) return int(ss)
@ -219,8 +223,12 @@ class EfficientNetBuilder:
self.norm_layer = norm_layer self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
logging.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
feature_location = 'expansion'
self.feature_location = feature_location self.feature_location = feature_location
assert feature_location in ('bottleneck', 'depthwise', 'expansion', '') assert feature_location in ('bottleneck', 'expansion', '')
self.verbose = verbose self.verbose = verbose
# state updated during build, consumed by model # state updated during build, consumed by model
@ -247,8 +255,7 @@ class EfficientNetBuilder:
if bt == 'ir': if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
if ba.get('num_experts', 0) > 0: if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba) block = CondConvResidual(**ba)
else: else:
@ -256,18 +263,15 @@ class EfficientNetBuilder:
elif bt == 'ds' or bt == 'dsa': elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba) block = DepthwiseSeparableConv(**ba)
elif bt == 'er': elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs ba['se_kwargs'] = self.se_kwargs
if self.verbose: _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
block = EdgeResidual(**ba) block = EdgeResidual(**ba)
elif bt == 'cn': elif bt == 'cn':
if self.verbose: _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
block = ConvBnAct(**ba) block = ConvBnAct(**ba)
else: else:
assert False, 'Uknkown block type (%s) while building model.' % bt assert False, 'Uknkown block type (%s) while building model.' % bt
@ -279,64 +283,55 @@ class EfficientNetBuilder:
""" Build the blocks """ Build the blocks
Args: Args:
in_chs: Number of input-channels passed to first block in_chs: Number of input-channels passed to first block
model_block_args: A list of lists, outer list defines stages, inner model_block_args: A list of lists, outer list defines stacks (block stages), inner
list contains strings defining block configuration(s) list contains strings defining block configuration(s)
Return: Return:
List of block stacks (each stack wrapped in nn.Sequential) List of block stacks (each stack wrapped in nn.Sequential)
""" """
if self.verbose: _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
logging.info('Building model trunk with %d stages...' % len(model_block_args))
self.in_chs = in_chs self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args]) total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0 total_block_idx = 0
current_stride = 2 current_stride = 2
current_dilation = 1 current_dilation = 1
stages = [] stages = []
# outer list of block_args defines the stacks ('stages' by some conventions) if model_block_args[0][0]['stride'] > 1:
for stage_idx, stage_block_args in enumerate(model_block_args): # if the first block starts with a stride, we need to extract first level feat from stem
last_stack = stage_idx == (len(model_block_args) - 1) feature_info = dict(
if self.verbose: module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
logging.info('Stack: {}'.format(stage_idx)) hook_type='forward' if self.feature_location != 'bottleneck' else '')
assert isinstance(stage_block_args, list) self.features.append(feature_info)
# outer list of block_args defines the stacks
for stack_idx, stack_args in enumerate(model_block_args):
last_stack = stack_idx + 1 == len(model_block_args)
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
assert isinstance(stack_args, list)
blocks = [] blocks = []
# each stack (stage) contains a list of block arguments # each stack (stage of blocks) contains a list of block arguments
for block_idx, block_args in enumerate(stage_block_args): for block_idx, block_args in enumerate(stack_args):
last_block = block_idx == (len(stage_block_args) - 1) last_block = block_idx + 1 == len(stack_args)
extract_features = '' # No features extracted _log_info_if(' Block: {}'.format(block_idx), self.verbose)
if self.verbose:
logging.info(' Block: {}'.format(block_idx))
# Sort out stride, dilation, and feature extraction details
assert block_args['stride'] in (1, 2) assert block_args['stride'] in (1, 2)
if block_idx >= 1: if block_idx >= 1: # only the first block in any stack can have a stride > 1
# only the first block in any stack can have a stride > 1
block_args['stride'] = 1 block_args['stride'] = 1
do_extract = False extract_features = False
if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
if last_block: if last_block:
next_stage_idx = stage_idx + 1 next_stack_idx = stack_idx + 1
if next_stage_idx >= len(model_block_args): extract_features = next_stack_idx >= len(model_block_args) or \
do_extract = True model_block_args[next_stack_idx][0]['stride'] > 1
else:
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
elif self.feature_location == 'expansion':
if block_args['stride'] > 1 or (last_stack and last_block):
do_extract = True
if do_extract:
extract_features = self.feature_location
next_dilation = current_dilation next_dilation = current_dilation
next_output_stride = current_stride
if block_args['stride'] > 1: if block_args['stride'] > 1:
next_output_stride = current_stride * block_args['stride'] next_output_stride = current_stride * block_args['stride']
if next_output_stride > self.output_stride: if next_output_stride > self.output_stride:
next_dilation = current_dilation * block_args['stride'] next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1 block_args['stride'] = 1
if self.verbose: _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( self.output_stride), self.verbose)
self.output_stride))
else: else:
current_stride = next_output_stride current_stride = next_output_stride
block_args['dilation'] = current_dilation block_args['dilation'] = current_dilation
@ -349,15 +344,11 @@ class EfficientNetBuilder:
# stash feature module name and channel info for model feature extraction # stash feature module name and channel info for model feature extraction
if extract_features: if extract_features:
feature_info = block.feature_info(extract_features) feature_info = dict(
module_name = f'blocks.{stage_idx}.{block_idx}' stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
if 'module' in feature_info and feature_info['module']: module_name = f'blocks.{stack_idx}.{block_idx}'
feature_info['module'] = '.'.join([module_name, feature_info['module']]) leaf_name = feature_info.get('module', '')
else: feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
feature_info['module'] = module_name
feature_info['stage_idx'] = stage_idx
feature_info['block_idx'] = block_idx
feature_info['reduction'] = current_stride
self.features.append(feature_info) self.features.append(feature_info)
total_block_idx += 1 # incr global block idx (across all stacks) total_block_idx += 1 # incr global block idx (across all stacks)

@ -162,11 +162,6 @@ class MobileNetV3Features(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None): norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(MobileNetV3Features, self).__init__() super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {} norm_kwargs = norm_kwargs or {}
# TODO only create stages needed, currently all stages are created regardless of out_indices
num_stages = max(out_indices) + 1
self.out_indices = out_indices
self.drop_rate = drop_rate self.drop_rate = drop_rate
self._in_chs = in_chans self._in_chs = in_chans
@ -183,14 +178,10 @@ class MobileNetV3Features(nn.Module):
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices) self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = { self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._in_chs = builder.in_chs self._in_chs = builder.in_chs
efficientnet_init_weights(self) efficientnet_init_weights(self)
if _DEBUG:
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
@ -204,9 +195,11 @@ class MobileNetV3Features(nn.Module):
x = self.act1(x) x = self.act1(x)
if self.feature_hooks is None: if self.feature_hooks is None:
features = [] features = []
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks): for i, b in enumerate(self.blocks):
x = b(x) x = b(x)
if i in self._stage_to_feature_idx: if i + 1 in self._stage_out_idx:
features.append(x) features.append(x)
return features return features
else: else:

Loading…
Cancel
Save