@ -23,6 +23,7 @@ import torch.nn as nn
import torch . nn . functional as F
from models . helpers import load_pretrained
from models . adaptive_avgmax_pool import SelectAdaptivePool2d
from models . conv2d_same import sconv2d
from data . transforms import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
__all__ = [ ' GenMobileNet ' , ' mnasnet0_50 ' , ' mnasnet0_75 ' , ' mnasnet1_00 ' , ' mnasnet1_40 ' ,
@ -45,10 +46,12 @@ default_cfgs = {
' mnasnet0_50 ' : _cfg ( url = ' ' ) ,
' mnasnet0_75 ' : _cfg ( url = ' ' ) ,
' mnasnet1_00 ' : _cfg ( url = ' ' ) ,
' tflite_mnasnet1_00 ' : _cfg ( url = ' ' , interpolation = ' bicubic ' ) ,
' mnasnet1_40 ' : _cfg ( url = ' ' ) ,
' semnasnet0_50 ' : _cfg ( url = ' ' ) ,
' semnasnet0_75 ' : _cfg ( url = ' ' ) ,
' semnasnet1_00 ' : _cfg ( url = ' ' ) ,
' tflite_semnasnet1_00 ' : _cfg ( url = ' ' , interpolation = ' bicubic ' ) ,
' semnasnet1_40 ' : _cfg ( url = ' ' ) ,
' mnasnet_small ' : _cfg ( url = ' ' ) ,
' mobilenetv1_1_00 ' : _cfg ( url = ' ' ) ,
@ -56,7 +59,7 @@ default_cfgs = {
' chamnetv1_1_00 ' : _cfg ( url = ' ' ) ,
' chamnetv2_1_00 ' : _cfg ( url = ' ' ) ,
' fbnetc_1_00 ' : _cfg ( url = ' ' ) ,
' spnasnet1_00 ' : _cfg ( url = ' ') ,
' spnasnet1_00 ' : _cfg ( url = ' https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1 ') ,
}
_DEBUG = True
@ -184,11 +187,15 @@ def _decode_block_str(block_str):
return [ deepcopy ( block_args ) for _ in range ( num_repeat ) ]
def _get_padding ( kernel_size , stride , dilation ) :
def _get_padding ( kernel_size , stride , dilation = 1 ) :
padding = ( ( stride - 1 ) + dilation * ( kernel_size - 1 ) ) / / 2
return padding
def _padding_arg ( default , padding_same = False ) :
return ' SAME ' if padding_same else default
def _decode_arch_args ( string_list ) :
block_args = [ ]
for block_str in string_list :
@ -219,12 +226,15 @@ class _BlockBuilder:
"""
def __init__ ( self , depth_multiplier = 1.0 , depth_divisor = 8 , min_depth = None ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ) :
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False ) :
self . depth_multiplier = depth_multiplier
self . depth_divisor = depth_divisor
self . min_depth = min_depth
self . bn_momentum = bn_momentum
self . bn_eps = bn_eps
self . folded_bn = folded_bn
self . padding_same = padding_same
self . in_chs = None
def _round_channels ( self , chs ) :
@ -236,6 +246,8 @@ class _BlockBuilder:
ba [ ' out_chs ' ] = _round_channels ( ba [ ' out_chs ' ] )
ba [ ' bn_momentum ' ] = self . bn_momentum
ba [ ' bn_eps ' ] = self . bn_eps
ba [ ' folded_bn ' ] = self . folded_bn
ba [ ' padding_same ' ] = self . padding_same
if _DEBUG :
print ( ' args: ' , ba )
# could replace this with lambdas or functools binding if variety increases
@ -320,29 +332,37 @@ def _initialize_weight_default(m):
class DepthwiseSeparableConv ( nn . Module ) :
def __init__ ( self , in_chs , out_chs , kernel_size ,
stride = 1 , act_fn = F . relu , noskip = False , pw_act = False ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ) :
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False ) :
super ( DepthwiseSeparableConv , self ) . __init__ ( )
assert stride in [ 1 , 2 ]
self . has_residual = ( stride == 1 and in_chs == out_chs ) and not noskip
self . has_pw_act = pw_act # activation after point-wise conv
self . act_fn = act_fn
dw_padding = _padding_arg ( kernel_size / / 2 , padding_same )
pw_padding = _padding_arg ( 0 , padding_same )
self . conv_dw = nn . Conv2d (
self . conv_dw = sc onv2d(
in_chs , in_chs , kernel_size ,
stride = stride , padding = kernel_size / / 2 , groups = in_chs , bias = False )
self . bn1 = nn . BatchNorm2d ( in_chs , momentum = bn_momentum , eps = bn_eps )
self . conv_pw = nn. Conv2d ( in_chs , out_chs , 1 , bias = False )
self . bn2 = nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
stride = stride , padding = dw_padding, groups = in_chs , bias = folded_bn )
self . bn1 = None if folded_bn else nn . BatchNorm2d ( in_chs , momentum = bn_momentum , eps = bn_eps )
self . conv_pw = sconv2d( in_chs , out_chs , 1 , padding = pw_padding , bias = folded_bn )
self . bn2 = None if folded_bn else nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
def forward ( self , x ) :
residual = x
x = self . conv_dw ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . conv_pw ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
if self . has_pw_act :
x = self . act_fn ( x )
if self . has_residual :
x + = residual
return x
@ -351,23 +371,27 @@ class DepthwiseSeparableConv(nn.Module):
class CascadeConv3x3 ( nn . Sequential ) :
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
def __init__ ( self , in_chs , out_chs , stride , act_fn = F . relu , noskip = False ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ) :
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False ) :
super ( CascadeConv3x3 , self ) . __init__ ( )
assert stride in [ 1 , 2 ]
self . has_residual = not noskip and ( stride == 1 and in_chs == out_chs )
self . has_residual = ( stride == 1 and in_chs == out_chs ) and not noskip
self . act_fn = act_fn
padding = _padding_arg ( 1 , padding_same )
self . conv1 = nn. C onv2d( in_chs , in_chs , 3 , stride = stride , padding = 1 , bias = False )
self . bn1 = nn . BatchNorm2d ( in_chs , momentum = bn_momentum , eps = bn_eps )
self . conv2 = nn. C onv2d( in_chs , out_chs , 3 , stride = 1 , padding = 1 , bias = False )
self . bn2 = nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
self . conv1 = sc onv2d( in_chs , in_chs , 3 , stride = stride , padding = padding , bias = folded_bn )
self . bn1 = None if folded_bn else nn . BatchNorm2d ( in_chs , momentum = bn_momentum , eps = bn_eps )
self . conv2 = sc onv2d( in_chs , out_chs , 3 , stride = 1 , padding = padding , bias = folded_bn )
self . bn2 = None if folded_bn else nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
def forward ( self , x ) :
residual = x
x = self . conv1 ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . conv2 ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
if self . has_residual :
x + = residual
@ -396,10 +420,10 @@ class ChannelShuffle(nn.Module):
class SqueezeExcite ( nn . Module ) :
def __init__ ( self , in_chs , se_ratio= 0.25 , act_fn = F . relu ) :
def __init__ ( self , in_chs , reduce_chs= None , act_fn = F . relu ) :
super ( SqueezeExcite , self ) . __init__ ( )
self . act_fn = act_fn
reduced_chs = max ( 1 , int ( in_chs * se_ratio ) )
reduced_chs = reduce_chs or in_chs
self . conv_reduce = nn . Conv2d ( in_chs , reduced_chs , 1 , bias = True )
self . conv_expand = nn . Conv2d ( reduced_chs , in_chs , 1 , bias = True )
@ -419,40 +443,43 @@ class InvertedResidual(nn.Module):
def __init__ ( self , in_chs , out_chs , kernel_size ,
stride = 1 , act_fn = F . relu , exp_ratio = 1.0 , noskip = False ,
se_ratio = 0. , shuffle_type = None , pw_group = 1 ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ) :
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False ) :
super ( InvertedResidual , self ) . __init__ ( )
mid_chs = int ( in_chs * exp_ratio )
self . has_se = se_ratio is not None and se_ratio > 0.
self . has_residual = ( in_chs == out_chs and stride == 1 ) and not noskip
self . act_fn = act_fn
dw_padding = _padding_arg ( kernel_size / / 2 , padding_same )
pw_padding = _padding_arg ( 0 , padding_same )
# Point-wise expansion
self . conv_pw = nn. C onv2d( in_chs , mid_chs , 1 , groups= pw_group , bias = False )
self . bn1 = nn . BatchNorm2d ( mid_chs , momentum = bn_momentum , eps = bn_eps )
self . conv_pw = sc onv2d( in_chs , mid_chs , 1 , padding= pw_padding , groups= pw_group , bias = folded_bn )
self . bn1 = None if folded_bn else nn . BatchNorm2d ( mid_chs , momentum = bn_momentum , eps = bn_eps )
self . shuffle_type = shuffle_type
if shuffle_type is not None :
self . shuffle = ChannelShuffle ( pw_group )
# Depth-wise convolution
self . conv_dw = nn . Conv2d (
mid_chs , mid_chs , kernel_size , padding = kernel_size / / 2 ,
stride = stride , groups = mid_chs , bias = False )
self . bn2 = nn . BatchNorm2d ( mid_chs , momentum = bn_momentum , eps = bn_eps )
self . conv_dw = sconv2d (
mid_chs , mid_chs , kernel_size , padding = dw_padding , stride = stride , groups = mid_chs , bias = folded_bn )
self . bn2 = None if folded_bn else nn . BatchNorm2d ( mid_chs , momentum = bn_momentum , eps = bn_eps )
# Squeeze-and-excitation
if self . has_se :
self . se = SqueezeExcite ( mid_chs , se_ratio)
self . se = SqueezeExcite ( mid_chs , reduce_chs= max ( 1 , int ( in_chs * se_ratio) ) )
# Point-wise linear projection
self . conv_pwl = nn. C onv2d( mid_chs , out_chs , 1 , groups= pw_group , bias = False )
self . bn3 = nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
self . conv_pwl = sc onv2d( mid_chs , out_chs , 1 , padding= pw_padding , groups= pw_group , bias = folded_bn )
self . bn3 = None if folded_bn else nn . BatchNorm2d ( out_chs , momentum = bn_momentum , eps = bn_eps )
def forward ( self , x ) :
residual = x
# Point-wise expansion
x = self . conv_pw ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
@ -463,6 +490,7 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
x = self . conv_dw ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
x = self . act_fn ( x )
@ -472,6 +500,7 @@ class InvertedResidual(nn.Module):
# Point-wise linear projection
x = self . conv_pwl ( x )
if self . bn3 is not None :
x = self . bn3 ( x )
if self . has_residual :
@ -498,7 +527,7 @@ class GenMobileNet(nn.Module):
depth_multiplier = 1.0 , depth_divisor = 8 , min_depth = None ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
drop_rate = 0. , act_fn = F . relu , global_pool = ' avg ' , skip_head_conv = False ,
weight_init = ' goog ' ):
weight_init = ' goog ' , folded_bn = False , padding_same = False ):
super ( GenMobileNet , self ) . __init__ ( )
self . num_classes = num_classes
self . depth_multiplier = depth_multiplier
@ -507,13 +536,15 @@ class GenMobileNet(nn.Module):
self . num_features = num_features
stem_size = _round_channels ( stem_size , depth_multiplier , depth_divisor , min_depth )
self . conv_stem = nn . Conv2d ( in_chans , stem_size , 3 , padding = 1 , stride = 2 , bias = False )
self . bn1 = nn . BatchNorm2d ( stem_size , momentum = bn_momentum , eps = bn_eps )
self . conv_stem = sconv2d (
in_chans , stem_size , 3 ,
padding = _padding_arg ( 1 , padding_same ) , stride = 2 , bias = folded_bn )
self . bn1 = None if folded_bn else nn . BatchNorm2d ( stem_size , momentum = bn_momentum , eps = bn_eps )
in_chs = stem_size
builder = _BlockBuilder (
depth_multiplier , depth_divisor , min_depth ,
bn_momentum , bn_eps )
bn_momentum , bn_eps , folded_bn , padding_same )
self . blocks = nn . Sequential ( * builder ( in_chs , block_args ) )
in_chs = builder . in_chs
@ -521,8 +552,10 @@ class GenMobileNet(nn.Module):
self . conv_head = None
assert in_chs == self . num_features
else :
self . conv_head = nn . Conv2d ( in_chs , self . num_features , 1 , padding = 0 , stride = 1 , bias = False )
self . bn2 = nn . BatchNorm2d ( self . num_features , momentum = bn_momentum , eps = bn_eps )
self . conv_head = sconv2d (
in_chs , self . num_features , 1 ,
padding = _padding_arg ( 0 , padding_same ) , bias = folded_bn )
self . bn2 = None if folded_bn else nn . BatchNorm2d ( self . num_features , momentum = bn_momentum , eps = bn_eps )
self . global_pool = SelectAdaptivePool2d ( pool_type = global_pool )
self . classifier = nn . Linear ( self . num_features , self . num_classes )
@ -548,11 +581,13 @@ class GenMobileNet(nn.Module):
def forward_features ( self , x , pool = True ) :
x = self . conv_stem ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . blocks ( x )
if self . conv_head is not None :
x = self . conv_head ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
x = self . act_fn ( x )
if pool :
@ -909,6 +944,19 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def tflite_mnasnet1_00 ( num_classes , in_chans = 3 , pretrained = False , * * kwargs ) :
""" MNASNet B1, depth multiplier of 1.0. """
default_cfg = default_cfgs [ ' tflite_mnasnet1_00 ' ]
# these two args are for compat with tflite pretrained weights
kwargs [ ' folded_bn ' ] = True
kwargs [ ' padding_same ' ] = True
model = _gen_mnasnet_b1 ( 1.0 , num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
def mnasnet1_40 ( num_classes , in_chans = 3 , pretrained = False , * * kwargs ) :
""" MNASNet B1, depth multiplier of 1.4 """
default_cfg = default_cfgs [ ' mnasnet1_40 ' ]
@ -949,6 +997,19 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def tflite_semnasnet1_00 ( num_classes , in_chans = 3 , pretrained = False , * * kwargs ) :
""" MNASNet A1, depth multiplier of 1.0. """
default_cfg = default_cfgs [ ' tflite_semnasnet1_00 ' ]
# these two args are for compat with tflite pretrained weights
kwargs [ ' folded_bn ' ] = True
kwargs [ ' padding_same ' ] = True
model = _gen_mnasnet_a1 ( 1.0 , num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
def semnasnet1_40 ( num_classes , in_chans = 3 , pretrained = False , * * kwargs ) :
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
default_cfg = default_cfgs [ ' semnasnet1_40 ' ]