From c146b54abce7eb7dc8103d831159accf0bd377ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 21 Jul 2020 01:21:38 -0700 Subject: [PATCH] Cleanup EfficientNet/MobileNetV3 feature extraction a bit, only two tap locations now, small mobilenetv3 models work --- timm/models/efficientnet.py | 19 ++---- timm/models/efficientnet_blocks.py | 25 +++---- timm/models/efficientnet_builder.py | 101 +++++++++++++--------------- timm/models/mobilenetv3.py | 15 ++--- 4 files changed, 64 insertions(+), 96 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7de4c8c4..21be2a96 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -416,12 +416,6 @@ class EfficientNetFeatures(nn.Module): se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(EfficientNetFeatures, self).__init__() norm_kwargs = norm_kwargs or {} - - # TODO only create stages needed, currently all stages are created regardless of out_indices - num_stages = max(out_indices) + 1 - - self.out_indices = out_indices - self.feature_location = feature_location self.drop_rate = drop_rate self._in_chs = in_chans @@ -439,14 +433,10 @@ class EfficientNetFeatures(nn.Module): norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) - self._stage_to_feature_idx = { - v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices} + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) - if _DEBUG: - for fi, v in enumerate(self.feature_info): - print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None @@ -460,14 +450,17 @@ class EfficientNetFeatures(nn.Module): x = self.act1(x) if self.feature_hooks is None: features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out for i, b in enumerate(self.blocks): x = b(x) - if i in self._stage_to_feature_idx: + if i + 1 in self._stage_out_idx: features.append(x) return features else: self.blocks(x) - return self.feature_hooks.get_output(x.device) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) def _create_effnet(model_kwargs, variant, pretrained=False): diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 236623ff..98758abf 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -128,10 +128,9 @@ class ConvBnAct(nn.Module): self.act1 = act_layer(inplace=True) def feature_info(self, location): - if location == 'expansion' or location == 'depthwise': - # no expansion or depthwise this block, use act after conv + if location == 'expansion': # output of conv after act, same as block coutput info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels) - else: # location == 'bottleneck' + else: # location == 'bottleneck', block output info = dict(module='', hook_type='', num_chs=self.conv.out_channels) return info @@ -175,12 +174,9 @@ class DepthwiseSeparableConv(nn.Module): self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() def feature_info(self, location): - if location == 'expansion': - # no expansion in this block, use depthwise, before SE - info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels) - elif location == 'depthwise': # after SE + if location == 'expansion': # after SE, input to PW info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) - else: # location == 'bottleneck' + else: # location == 'bottleneck', block output info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) return info @@ -245,11 +241,9 @@ class InvertedResidual(nn.Module): self.bn3 = norm_layer(out_chs, **norm_kwargs) def feature_info(self, location): - if location == 'expansion': - info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels) - elif location == 'depthwise': # after SE + if location == 'expansion': # after SE, input to PWL info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) - else: # location == 'bottleneck' + else: # location == 'bottleneck', block output info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) return info @@ -370,12 +364,9 @@ class EdgeResidual(nn.Module): self.bn2 = norm_layer(out_chs, **norm_kwargs) def feature_info(self, location): - if location == 'expansion': - info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels) - elif location == 'depthwise': - # there is no depthwise, take after SE, before PWL + if location == 'expansion': # after SE, before PWL info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) - else: # location == 'bottleneck' + else: # location == 'bottleneck', block output info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) return info diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 68d39c86..9e5f3b94 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -1,7 +1,6 @@ import logging import math import re -from collections import OrderedDict from copy import deepcopy import torch.nn as nn @@ -12,6 +11,11 @@ from .layers import CondConv2d, get_condconv_initializer __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"] +def _log_info_if(msg, condition): + if condition: + logging.info(msg) + + def _parse_ksize(ss): if ss.isdigit(): return int(ss) @@ -219,8 +223,12 @@ class EfficientNetBuilder: self.norm_layer = norm_layer self.norm_kwargs = norm_kwargs self.drop_path_rate = drop_path_rate + if feature_location == 'depthwise': + # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense + logging.warning("feature_location=='depthwise' is deprecated, using 'expansion'") + feature_location = 'expansion' self.feature_location = feature_location - assert feature_location in ('bottleneck', 'depthwise', 'expansion', '') + assert feature_location in ('bottleneck', 'expansion', '') self.verbose = verbose # state updated during build, consumed by model @@ -247,8 +255,7 @@ class EfficientNetBuilder: if bt == 'ir': ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs - if self.verbose: - logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) + _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) if ba.get('num_experts', 0) > 0: block = CondConvResidual(**ba) else: @@ -256,18 +263,15 @@ class EfficientNetBuilder: elif bt == 'ds' or bt == 'dsa': ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs - if self.verbose: - logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) + _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = DepthwiseSeparableConv(**ba) elif bt == 'er': ba['drop_path_rate'] = drop_path_rate ba['se_kwargs'] = self.se_kwargs - if self.verbose: - logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) + _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = EdgeResidual(**ba) elif bt == 'cn': - if self.verbose: - logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) + _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' % bt @@ -279,64 +283,55 @@ class EfficientNetBuilder: """ Build the blocks Args: in_chs: Number of input-channels passed to first block - model_block_args: A list of lists, outer list defines stages, inner + model_block_args: A list of lists, outer list defines stacks (block stages), inner list contains strings defining block configuration(s) Return: List of block stacks (each stack wrapped in nn.Sequential) """ - if self.verbose: - logging.info('Building model trunk with %d stages...' % len(model_block_args)) + _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose) self.in_chs = in_chs total_block_count = sum([len(x) for x in model_block_args]) total_block_idx = 0 current_stride = 2 current_dilation = 1 stages = [] - # outer list of block_args defines the stacks ('stages' by some conventions) - for stage_idx, stage_block_args in enumerate(model_block_args): - last_stack = stage_idx == (len(model_block_args) - 1) - if self.verbose: - logging.info('Stack: {}'.format(stage_idx)) - assert isinstance(stage_block_args, list) + if model_block_args[0][0]['stride'] > 1: + # if the first block starts with a stride, we need to extract first level feat from stem + feature_info = dict( + module='act1', num_chs=in_chs, stage=0, reduction=current_stride, + hook_type='forward' if self.feature_location != 'bottleneck' else '') + self.features.append(feature_info) + + # outer list of block_args defines the stacks + for stack_idx, stack_args in enumerate(model_block_args): + last_stack = stack_idx + 1 == len(model_block_args) + _log_info_if('Stack: {}'.format(stack_idx), self.verbose) + assert isinstance(stack_args, list) blocks = [] - # each stack (stage) contains a list of block arguments - for block_idx, block_args in enumerate(stage_block_args): - last_block = block_idx == (len(stage_block_args) - 1) - extract_features = '' # No features extracted - if self.verbose: - logging.info(' Block: {}'.format(block_idx)) - - # Sort out stride, dilation, and feature extraction details + # each stack (stage of blocks) contains a list of block arguments + for block_idx, block_args in enumerate(stack_args): + last_block = block_idx + 1 == len(stack_args) + _log_info_if(' Block: {}'.format(block_idx), self.verbose) + assert block_args['stride'] in (1, 2) - if block_idx >= 1: - # only the first block in any stack can have a stride > 1 + if block_idx >= 1: # only the first block in any stack can have a stride > 1 block_args['stride'] = 1 - do_extract = False - if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise': - if last_block: - next_stage_idx = stage_idx + 1 - if next_stage_idx >= len(model_block_args): - do_extract = True - else: - do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 - elif self.feature_location == 'expansion': - if block_args['stride'] > 1 or (last_stack and last_block): - do_extract = True - if do_extract: - extract_features = self.feature_location + extract_features = False + if last_block: + next_stack_idx = stack_idx + 1 + extract_features = next_stack_idx >= len(model_block_args) or \ + model_block_args[next_stack_idx][0]['stride'] > 1 next_dilation = current_dilation - next_output_stride = current_stride if block_args['stride'] > 1: next_output_stride = current_stride * block_args['stride'] if next_output_stride > self.output_stride: next_dilation = current_dilation * block_args['stride'] block_args['stride'] = 1 - if self.verbose: - logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( - self.output_stride)) + _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride), self.verbose) else: current_stride = next_output_stride block_args['dilation'] = current_dilation @@ -349,15 +344,11 @@ class EfficientNetBuilder: # stash feature module name and channel info for model feature extraction if extract_features: - feature_info = block.feature_info(extract_features) - module_name = f'blocks.{stage_idx}.{block_idx}' - if 'module' in feature_info and feature_info['module']: - feature_info['module'] = '.'.join([module_name, feature_info['module']]) - else: - feature_info['module'] = module_name - feature_info['stage_idx'] = stage_idx - feature_info['block_idx'] = block_idx - feature_info['reduction'] = current_stride + feature_info = dict( + stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location)) + module_name = f'blocks.{stack_idx}.{block_idx}' + leaf_name = feature_info.get('module', '') + feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name self.features.append(feature_info) total_block_idx += 1 # incr global block idx (across all stacks) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index f8e3d738..7e4af274 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -162,11 +162,6 @@ class MobileNetV3Features(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(MobileNetV3Features, self).__init__() norm_kwargs = norm_kwargs or {} - - # TODO only create stages needed, currently all stages are created regardless of out_indices - num_stages = max(out_indices) + 1 - - self.out_indices = out_indices self.drop_rate = drop_rate self._in_chs = in_chans @@ -183,14 +178,10 @@ class MobileNetV3Features(nn.Module): norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) - self._stage_to_feature_idx = { - v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices} + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) - if _DEBUG: - for fi, v in enumerate(self.feature_info): - print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None @@ -204,9 +195,11 @@ class MobileNetV3Features(nn.Module): x = self.act1(x) if self.feature_hooks is None: features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out for i, b in enumerate(self.blocks): x = b(x) - if i in self._stage_to_feature_idx: + if i + 1 in self._stage_out_idx: features.append(x) return features else: