|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
import logging
|
|
|
|
|
import math
|
|
|
|
|
import re
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -12,6 +11,11 @@ from .layers import CondConv2d, get_condconv_initializer
|
|
|
|
|
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _log_info_if(msg, condition):
|
|
|
|
|
if condition:
|
|
|
|
|
logging.info(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_ksize(ss):
|
|
|
|
|
if ss.isdigit():
|
|
|
|
|
return int(ss)
|
|
|
|
@ -219,8 +223,12 @@ class EfficientNetBuilder:
|
|
|
|
|
self.norm_layer = norm_layer
|
|
|
|
|
self.norm_kwargs = norm_kwargs
|
|
|
|
|
self.drop_path_rate = drop_path_rate
|
|
|
|
|
if feature_location == 'depthwise':
|
|
|
|
|
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
|
|
|
|
|
logging.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
|
|
|
|
|
feature_location = 'expansion'
|
|
|
|
|
self.feature_location = feature_location
|
|
|
|
|
assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
|
|
|
|
|
assert feature_location in ('bottleneck', 'expansion', '')
|
|
|
|
|
self.verbose = verbose
|
|
|
|
|
|
|
|
|
|
# state updated during build, consumed by model
|
|
|
|
@ -247,8 +255,7 @@ class EfficientNetBuilder:
|
|
|
|
|
if bt == 'ir':
|
|
|
|
|
ba['drop_path_rate'] = drop_path_rate
|
|
|
|
|
ba['se_kwargs'] = self.se_kwargs
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
|
|
|
|
|
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
|
|
|
|
if ba.get('num_experts', 0) > 0:
|
|
|
|
|
block = CondConvResidual(**ba)
|
|
|
|
|
else:
|
|
|
|
@ -256,18 +263,15 @@ class EfficientNetBuilder:
|
|
|
|
|
elif bt == 'ds' or bt == 'dsa':
|
|
|
|
|
ba['drop_path_rate'] = drop_path_rate
|
|
|
|
|
ba['se_kwargs'] = self.se_kwargs
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
|
|
|
|
|
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
|
|
|
|
block = DepthwiseSeparableConv(**ba)
|
|
|
|
|
elif bt == 'er':
|
|
|
|
|
ba['drop_path_rate'] = drop_path_rate
|
|
|
|
|
ba['se_kwargs'] = self.se_kwargs
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
|
|
|
|
|
_log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
|
|
|
|
block = EdgeResidual(**ba)
|
|
|
|
|
elif bt == 'cn':
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
|
|
|
|
|
_log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
|
|
|
|
block = ConvBnAct(**ba)
|
|
|
|
|
else:
|
|
|
|
|
assert False, 'Uknkown block type (%s) while building model.' % bt
|
|
|
|
@ -279,64 +283,55 @@ class EfficientNetBuilder:
|
|
|
|
|
""" Build the blocks
|
|
|
|
|
Args:
|
|
|
|
|
in_chs: Number of input-channels passed to first block
|
|
|
|
|
model_block_args: A list of lists, outer list defines stages, inner
|
|
|
|
|
model_block_args: A list of lists, outer list defines stacks (block stages), inner
|
|
|
|
|
list contains strings defining block configuration(s)
|
|
|
|
|
Return:
|
|
|
|
|
List of block stacks (each stack wrapped in nn.Sequential)
|
|
|
|
|
"""
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info('Building model trunk with %d stages...' % len(model_block_args))
|
|
|
|
|
_log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
|
|
|
|
|
self.in_chs = in_chs
|
|
|
|
|
total_block_count = sum([len(x) for x in model_block_args])
|
|
|
|
|
total_block_idx = 0
|
|
|
|
|
current_stride = 2
|
|
|
|
|
current_dilation = 1
|
|
|
|
|
stages = []
|
|
|
|
|
# outer list of block_args defines the stacks ('stages' by some conventions)
|
|
|
|
|
for stage_idx, stage_block_args in enumerate(model_block_args):
|
|
|
|
|
last_stack = stage_idx == (len(model_block_args) - 1)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info('Stack: {}'.format(stage_idx))
|
|
|
|
|
assert isinstance(stage_block_args, list)
|
|
|
|
|
if model_block_args[0][0]['stride'] > 1:
|
|
|
|
|
# if the first block starts with a stride, we need to extract first level feat from stem
|
|
|
|
|
feature_info = dict(
|
|
|
|
|
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
|
|
|
|
|
hook_type='forward' if self.feature_location != 'bottleneck' else '')
|
|
|
|
|
self.features.append(feature_info)
|
|
|
|
|
|
|
|
|
|
# outer list of block_args defines the stacks
|
|
|
|
|
for stack_idx, stack_args in enumerate(model_block_args):
|
|
|
|
|
last_stack = stack_idx + 1 == len(model_block_args)
|
|
|
|
|
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
|
|
|
|
|
assert isinstance(stack_args, list)
|
|
|
|
|
|
|
|
|
|
blocks = []
|
|
|
|
|
# each stack (stage) contains a list of block arguments
|
|
|
|
|
for block_idx, block_args in enumerate(stage_block_args):
|
|
|
|
|
last_block = block_idx == (len(stage_block_args) - 1)
|
|
|
|
|
extract_features = '' # No features extracted
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' Block: {}'.format(block_idx))
|
|
|
|
|
|
|
|
|
|
# Sort out stride, dilation, and feature extraction details
|
|
|
|
|
# each stack (stage of blocks) contains a list of block arguments
|
|
|
|
|
for block_idx, block_args in enumerate(stack_args):
|
|
|
|
|
last_block = block_idx + 1 == len(stack_args)
|
|
|
|
|
_log_info_if(' Block: {}'.format(block_idx), self.verbose)
|
|
|
|
|
|
|
|
|
|
assert block_args['stride'] in (1, 2)
|
|
|
|
|
if block_idx >= 1:
|
|
|
|
|
# only the first block in any stack can have a stride > 1
|
|
|
|
|
if block_idx >= 1: # only the first block in any stack can have a stride > 1
|
|
|
|
|
block_args['stride'] = 1
|
|
|
|
|
|
|
|
|
|
do_extract = False
|
|
|
|
|
if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
|
|
|
|
|
if last_block:
|
|
|
|
|
next_stage_idx = stage_idx + 1
|
|
|
|
|
if next_stage_idx >= len(model_block_args):
|
|
|
|
|
do_extract = True
|
|
|
|
|
else:
|
|
|
|
|
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
|
|
|
|
|
elif self.feature_location == 'expansion':
|
|
|
|
|
if block_args['stride'] > 1 or (last_stack and last_block):
|
|
|
|
|
do_extract = True
|
|
|
|
|
if do_extract:
|
|
|
|
|
extract_features = self.feature_location
|
|
|
|
|
extract_features = False
|
|
|
|
|
if last_block:
|
|
|
|
|
next_stack_idx = stack_idx + 1
|
|
|
|
|
extract_features = next_stack_idx >= len(model_block_args) or \
|
|
|
|
|
model_block_args[next_stack_idx][0]['stride'] > 1
|
|
|
|
|
|
|
|
|
|
next_dilation = current_dilation
|
|
|
|
|
next_output_stride = current_stride
|
|
|
|
|
if block_args['stride'] > 1:
|
|
|
|
|
next_output_stride = current_stride * block_args['stride']
|
|
|
|
|
if next_output_stride > self.output_stride:
|
|
|
|
|
next_dilation = current_dilation * block_args['stride']
|
|
|
|
|
block_args['stride'] = 1
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format(
|
|
|
|
|
self.output_stride))
|
|
|
|
|
_log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
|
|
|
|
|
self.output_stride), self.verbose)
|
|
|
|
|
else:
|
|
|
|
|
current_stride = next_output_stride
|
|
|
|
|
block_args['dilation'] = current_dilation
|
|
|
|
@ -349,15 +344,11 @@ class EfficientNetBuilder:
|
|
|
|
|
|
|
|
|
|
# stash feature module name and channel info for model feature extraction
|
|
|
|
|
if extract_features:
|
|
|
|
|
feature_info = block.feature_info(extract_features)
|
|
|
|
|
module_name = f'blocks.{stage_idx}.{block_idx}'
|
|
|
|
|
if 'module' in feature_info and feature_info['module']:
|
|
|
|
|
feature_info['module'] = '.'.join([module_name, feature_info['module']])
|
|
|
|
|
else:
|
|
|
|
|
feature_info['module'] = module_name
|
|
|
|
|
feature_info['stage_idx'] = stage_idx
|
|
|
|
|
feature_info['block_idx'] = block_idx
|
|
|
|
|
feature_info['reduction'] = current_stride
|
|
|
|
|
feature_info = dict(
|
|
|
|
|
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
|
|
|
|
|
module_name = f'blocks.{stack_idx}.{block_idx}'
|
|
|
|
|
leaf_name = feature_info.get('module', '')
|
|
|
|
|
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
|
|
|
|
|
self.features.append(feature_info)
|
|
|
|
|
|
|
|
|
|
total_block_idx += 1 # incr global block idx (across all stacks)
|
|
|
|
|