@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
and object detection models .
"""
def __init__ ( self , block_args , out_indices = ( 0 , 1 , 2 , 3 , 4 ) , feature_location = ' pre_pwl ' ,
def __init__ ( self , block_args , out_indices = ( 0 , 1 , 2 , 3 , 4 ) , feature_location = ' bottleneck ' ,
in_chans = 3 , stem_size = 16 , channel_multiplier = 1.0 , output_stride = 32 , pad_type = ' ' ,
act_layer = nn . ReLU , drop_rate = 0. , drop_path_rate = 0. , se_kwargs = None ,
norm_layer = nn . BatchNorm2d , norm_kwargs = None ) :
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
channel_multiplier , 8 , None , output_stride , pad_type , act_layer , se_kwargs ,
norm_layer , norm_kwargs , drop_path_rate , feature_location = feature_location , verbose = _DEBUG )
self . blocks = nn . Sequential ( * builder ( self . _in_chs , block_args ) )
self . feature_info = builder . features # builder provides info about feature channels for each block
self . _feature_info = builder . features # builder provides info about feature channels for each block
self . _stage_to_feature_idx = {
v [ ' stage_idx ' ] : fi for fi , v in self . _feature_info . items ( ) if fi in self . out_indices }
self . _in_chs = builder . in_chs
efficientnet_init_weights ( self )
if _DEBUG :
for k , v in self . feature_info. items ( ) :
for k , v in self . _ feature_info. items ( ) :
print ( ' Feature idx: {} : Name: {} , Channels: {} ' . format ( k , v [ ' name ' ] , v [ ' num_chs ' ] ) )
# Register feature extraction hooks with FeatureHooks helper
hook_type = ' forward_pre ' if feature_location == ' pre_pwl ' else ' forward '
hooks = [ dict ( name = self . feature_info [ idx ] [ ' name ' ] , type = hook_type ) for idx in out_indices ]
self . feature_hooks = FeatureHooks ( hooks , self . named_modules ( ) )
self . feature_hooks = None
if feature_location != ' bottleneck ' :
hooks = [ dict (
name = self . _feature_info [ idx ] [ ' module ' ] ,
type = self . _feature_info [ idx ] [ ' hook_type ' ] ) for idx in out_indices ]
self . feature_hooks = FeatureHooks ( hooks , self . named_modules ( ) )
def feature_channels ( self , idx = None ) :
""" Feature Channel Shortcut
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
return feature channel count for that feature block index ( independent of out_indices setting ) .
"""
if isinstance ( idx , int ) :
return self . feature_info[ idx ] [ ' num_chs ' ]
return [ self . feature_info[ i ] [ ' num_chs ' ] for i in self . out_indices ]
return self . _ feature_info[ idx ] [ ' num_chs ' ]
return [ self . _ feature_info[ i ] [ ' num_chs ' ] for i in self . out_indices ]
def forward ( self , x ) :
x = self . conv_stem ( x )
x = self . bn1 ( x )
x = self . act1 ( x )
self . blocks ( x )
return self . feature_hooks . get_output ( x . device )
if self . feature_hooks is None :
features = [ ]
for i , b in enumerate ( self . blocks ) :
x = b ( x )
if i in self . _stage_to_feature_idx :
features . append ( x )
return features
else :
self . blocks ( x )
return self . feature_hooks . get_output ( x . device )
def _create_model ( model_kwargs , default_cfg , pretrained = False ) :