@@ -50,18 +50,12 @@ default_cfgs = {
'mnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
interpolation='bicubic'),
'tflite_mnasnet_100': _cfg(
url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
interpolation='bicubic'),
'mnasnet_140': _cfg(url=''),
'semnasnet_050': _cfg(url=''),
'semnasnet_075': _cfg(url=''),
'semnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
interpolation='bicubic'),
'tflite_semnasnet_100': _cfg(
url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet_100-7c780429.pth?dl=1',
interpolation='bicubic'),
'semnasnet_140': _cfg(url=''),
'mnasnet_small': _cfg(url=''),
'mobilenetv1_100': _cfg(url=''),
@@ -118,6 +112,7 @@ _DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
@@ -126,23 +121,18 @@ _BN_EPS_PT_DEFAULT = 1e-5
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
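# Illustrative sketch (not part of the patch) of the decay/momentum relationship noted
# above: TF-style BN keeps running stats with an EMA decay, and PyTorch expresses the
# same update as momentum = 1 - decay, hence the 1 - 0.99 default above.
_example_tf_decay = 0.9997  # hypothetical value, the decay quoted for the paper
_example_pt_momentum = 1. - _example_tf_decay  # -> 3e-4, usable as nn.BatchNorm2d(momentum=...)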
def _resolve_bn_params(kwargs):
# NOTE kwargs passed as dict intentionally
bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT
bn_eps_default = _BN_EPS_PT_DEFAULT
bn_tf = kwargs.pop('bn_tf', False)
if bn_tf:
bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT
bn_eps_default = _BN_EPS_TF_DEFAULT
def _resolve_bn_args(kwargs):
bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_momentum is None:
bn_momentum = bn_momentum_default
if bn_eps is None:
bn_eps = bn_eps_default
return bn_momentum, bn_eps
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
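# Usage sketch for _resolve_bn_args (illustrative, not part of the patch): 'bn_tf'
# selects the TF defaults, explicit 'bn_momentum'/'bn_eps' values override whichever
# default set was chosen, and all three keys are popped so they never reach the model ctor.
_example_kwargs = dict(bn_tf=True, bn_momentum=0.01)
_example_bn_args = _resolve_bn_args(_example_kwargs)
# -> {'momentum': 0.01, 'eps': 1e-3}, and _example_kwargs is now empty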
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
@@ -292,6 +282,31 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0):
return arch_args
def swish(x, inplace=False):
if inplace:
return x.mul_(x.sigmoid())
else:
return x * x.sigmoid()
def sigmoid(x, inplace=False):
return x.sigmoid_() if inplace else x.sigmoid()
def hard_swish(x, inplace=False):
if inplace:
return x.mul_(F.relu6(x + 3.) / 6.)
else:
return x * F.relu6(x + 3.) / 6.
def hard_sigmoid(x, inplace=False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
return F.relu6(x + 3.) / 6.
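# Sanity-check sketch for the activations above (illustrative, not part of the patch):
# the hard variants are piecewise-linear approximations, and the identities below follow
# directly from the definitions. The inplace forms mutate their input tensor, so they are
# only safe on intermediate activations.
_x = torch.linspace(-4., 4., 9)
assert torch.allclose(swish(_x), _x * sigmoid(_x))
assert torch.allclose(hard_swish(_x), _x * hard_sigmoid(_x))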
class _BlockBuilder:
""" Build Trunk Blocks
@@ -303,9 +318,9 @@ class _BlockBuilder:
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False, verbose=False):
drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
bn_args=_BN_ARGS_PT, padding_same=False,
verbose=False):
self . channel_multiplier = channel_multiplier
self . channel_divisor = channel_divisor
self . channel_min = channel_min
@@ -313,9 +328,7 @@ class _BlockBuilder:
self . act_fn = act_fn
self . se_gate_fn = se_gate_fn
self . se_reduce_mid = se_reduce_mid
self . bn_momentum = bn_momentum
self . bn_eps = bn_eps
self . folded_bn = folded_bn
self . bn_args = bn_args
self . padding_same = padding_same
self . verbose = verbose
@@ -331,9 +344,7 @@ class _BlockBuilder:
bt = ba . pop ( ' block_type ' )
ba [ ' in_chs ' ] = self . in_chs
ba [ ' out_chs ' ] = self . _round_channels ( ba [ ' out_chs ' ] )
ba [ ' bn_momentum ' ] = self . bn_momentum
ba [ ' bn_eps ' ] = self . bn_eps
ba [ ' folded_bn ' ] = self . folded_bn
ba [ ' bn_args ' ] = self . bn_args
ba [ ' padding_same ' ] = self . padding_same
# block act fn overrides the model default
ba [ ' act_fn ' ] = ba [ ' act_fn ' ] if ba [ ' act_fn ' ] is not None else self . act_fn
@@ -427,18 +438,6 @@ def _initialize_weight_default(m):
nn . init . kaiming_uniform_ ( m . weight , mode = ' fan_in ' , nonlinearity = ' linear ' )
def swish ( x ) :
return x * torch . sigmoid ( x )
def hard_swish ( x ) :
return x * F . relu6 ( x + 3. ) / 6.
def hard_sigmoid ( x ) :
return F . relu6 ( x + 3. ) / 6.
def drop_connect ( inputs , training = False , drop_connect_rate = 0. ) :
""" Apply drop connect. """
if not training :
@@ -474,7 +473,7 @@ class ChannelShuffle(nn.Module):
class SqueezeExcite ( nn . Module ) :
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid):
super ( SqueezeExcite , self ) . __init__ ( )
self . act_fn = act_fn
self . gate_fn = gate_fn
@@ -486,17 +485,16 @@ class SqueezeExcite(nn.Module):
# NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
x_se = x . view ( x . size ( 0 ) , x . size ( 1 ) , - 1 ) . mean ( - 1 ) . view ( x . size ( 0 ) , x . size ( 1 ) , 1 , 1 )
x_se = self . conv_reduce ( x_se )
x_se = self . act_fn ( x_se )
x_se = self . act_fn ( x_se , inplace = True )
x_se = self . conv_expand ( x_se )
x = self . gate_fn ( x_se ) * x
x = x * self . gate_fn ( x_se )
return x
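# Illustrative note (not part of the patch): the view/mean pooling used above computes
# the same result as adaptive average pooling to 1x1; a rough equivalence check would be
#   _t = torch.randn(2, 8, 7, 7)
#   _manual = _t.view(_t.size(0), _t.size(1), -1).mean(-1).view(_t.size(0), _t.size(1), 1, 1)
#   assert torch.allclose(_manual, F.adaptive_avg_pool2d(_t, 1))
# the manual form is only preferred here because of the AMP performance concern in the NOTE.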
class ConvBnAct ( nn . Module ) :
def __init__ ( self , in_chs , out_chs , kernel_size ,
stride = 1 , act_fn = F . relu ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False ) :
bn_args = _BN_ARGS_PT , padding_same = False ) :
super ( ConvBnAct , self ) . __init__ ( )
assert stride in [ 1 , 2 ]
self . act_fn = act_fn
@@ -504,14 +502,13 @@ class ConvBnAct(nn.Module):
self . conv = sconv2d (
in_chs , out_chs , kernel_size ,
stride = stride , padding = padding , bias = folded_bn )
self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
stride=stride, padding=padding, bias=False)
self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)
def forward ( self , x ) :
x = self . conv ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
return x
@@ -522,9 +519,8 @@ class DepthwiseSeparableConv(nn.Module):
"""
def __init__ ( self , in_chs , out_chs , kernel_size ,
stride = 1 , act_fn = F . relu , noskip = False , pw_act = False ,
se_ratio = 0. , se_gate_fn = torch . sigmoid ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False , drop_connect_rate = 0. ) :
se_ratio = 0. , se_gate_fn = sigmoid ,
bn_args = _BN_ARGS_PT , padding_same = False , drop_connect_rate = 0. ) :
super ( DepthwiseSeparableConv , self ) . __init__ ( )
assert stride in [ 1 , 2 ]
self . has_se = se_ratio is not None and se_ratio > 0.
@@ -537,33 +533,31 @@ class DepthwiseSeparableConv(nn.Module):
self . conv_dw = sconv2d (
in_chs , in_chs , kernel_size ,
stride = stride , padding = dw_padding , groups = in_chs , bias = folded_bn )
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
stride=stride, padding=dw_padding, groups=in_chs, bias=False)
self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)
# Squeeze-and-excitation
if self . has_se :
self . se = SqueezeExcite (
in_chs , reduce_chs = max ( 1 , int ( in_chs * se_ratio ) ) , act_fn = act_fn , gate_fn = se_gate_fn )
self . conv_pw = sconv2d ( in_chs , out_chs , 1 , padding = pw_padding , bias = folded_bn )
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
def forward ( self , x ) :
residual = x
x = self . conv_dw ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
if self . has_se :
x = self . se ( x )
x = self . conv_pw ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
if self . has_pw_act :
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
if self . has_residual :
if self . drop_connect_rate > 0. :
@@ -577,10 +571,9 @@ class InvertedResidual(nn.Module):
def __init__ ( self , in_chs , out_chs , kernel_size ,
stride = 1 , act_fn = F . relu , exp_ratio = 1.0 , noskip = False ,
se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
shuffle_type = None , pw_group = 1 ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
folded_bn = False , padding_same = False , drop_connect_rate = 0. ) :
bn_args = _BN_ARGS_PT , padding_same = False , drop_connect_rate = 0. ) :
super ( InvertedResidual , self ) . __init__ ( )
mid_chs = int ( in_chs * exp_ratio )
self . has_se = se_ratio is not None and se_ratio > 0.
@@ -591,8 +584,8 @@ class InvertedResidual(nn.Module):
pw_padding = _padding_arg ( 0 , padding_same )
# Point-wise expansion
self . conv_pw = sconv2d ( in_chs , mid_chs , 1 , padding = pw_padding , groups = pw_group , bias = folded_bn )
self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
self . shuffle_type = shuffle_type
if shuffle_type is not None :
@@ -600,8 +593,8 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
self . conv_dw = sconv2d (
mid_chs , mid_chs , kernel_size , padding = dw_padding , stride = stride , groups = mid_chs , bias = folded_bn )
self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)
# Squeeze-and-excitation
if self . has_se :
@@ -610,17 +603,16 @@ class InvertedResidual(nn.Module):
mid_chs , reduce_chs = max ( 1 , int ( se_base_chs * se_ratio ) ) , act_fn = act_fn , gate_fn = se_gate_fn )
# Point-wise linear projection
self . conv_pwl = sconv2d ( mid_chs , out_chs , 1 , padding = pw_padding , groups = pw_group , bias = folded_bn )
self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
def forward ( self , x ) :
residual = x
# Point-wise expansion
x = self . conv_pw ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
# FIXME haven't tried this yet
# for channel shuffle when using groups with pointwise convs as per FBNet variants
@@ -629,9 +621,8 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
x = self . conv_dw ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
# Squeeze-and-excitation
if self . has_se :
@@ -639,7 +630,6 @@ class InvertedResidual(nn.Module):
# Point-wise linear projection
x = self . conv_pwl ( x )
if self . bn3 is not None :
x = self . bn3 ( x )
if self . has_residual :
@@ -668,11 +658,9 @@ class GenEfficientNet(nn.Module):
def __init__ ( self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 32 , num_features = 1280 ,
channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
drop_rate = 0. , drop_connect_rate = 0. , act_fn = F . relu ,
se_gate_fn = torch . sigmoid , se_reduce_mid = False ,
global_pool = ' avg ' , head_conv = ' default ' , weight_init = ' goog ' ,
folded_bn = False , padding_same = False , ) :
se_gate_fn = sigmoid , se_reduce_mid = False , bn_args = _BN_ARGS_PT ,
global_pool = ' avg ' , head_conv = ' default ' , weight_init = ' goog ' , padding_same = False ) :
super ( GenEfficientNet , self ) . __init__ ( )
self . num_classes = num_classes
self . drop_rate = drop_rate
@@ -682,14 +670,14 @@ class GenEfficientNet(nn.Module):
stem_size = _round_channels ( stem_size , channel_multiplier , channel_divisor , channel_min )
self . conv_stem = sconv2d (
in_chans , stem_size , 3 ,
padding = _padding_arg ( 1 , padding_same ) , stride = 2 , bias = folded_bn )
self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
padding=_padding_arg(1, padding_same), stride=2, bias=False)
self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
in_chs = stem_size
builder = _BlockBuilder (
channel_multiplier , channel_divisor , channel_min ,
drop_connect_rate , act_fn , se_gate_fn , se_reduce_mid ,
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
bn_args, padding_same, verbose=_DEBUG)
self . blocks = nn . Sequential ( * builder ( in_chs , block_args ) )
in_chs = builder . in_chs
@@ -701,9 +689,8 @@ class GenEfficientNet(nn.Module):
self . efficient_head = head_conv == ' efficient '
self . conv_head = sconv2d (
in_chs , self . num_features , 1 ,
padding = _padding_arg ( 0 , padding_same ) , bias = folded_bn and not self . efficient_head )
self . bn2 = None if ( folded_bn or self . efficient_head ) else \
nn . BatchNorm2d ( self . num_features , momentum = bn_momentum , eps = bn_eps )
padding = _padding_arg ( 0 , padding_same ) , bias = False )
self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)
self . global_pool = SelectAdaptivePool2d ( pool_type = global_pool )
self . classifier = nn . Linear ( self . num_features * self . global_pool . feat_mult ( ) , self . num_classes )
@@ -729,25 +716,23 @@ class GenEfficientNet(nn.Module):
def forward_features ( self , x , pool = True ) :
x = self . conv_stem ( x )
if self . bn1 is not None :
x = self . bn1 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
x = self . blocks ( x )
if self . efficient_head :
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
x = self . global_pool ( x ) # always need to pool here regardless of flag
x = self . conv_head ( x )
# no BN
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
if pool :
# expect flattened output if pool is true, otherwise keep dim
x = x . view ( x . size ( 0 ) , - 1 )
else :
if self . conv_head is not None :
x = self . conv_head ( x )
if self . bn2 is not None :
x = self . bn2 ( x )
x = self . act_fn ( x )
x = self . act_fn ( x , inplace = True )
if pool :
x = self . global_pool ( x )
x = x . view ( x . size ( 0 ) , - 1 )
@@ -785,7 +770,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
[ ' ir_r1_k3_s1_e6_c320 ' ] ,
]
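# Decoding note for the block strings above (my reading of the encoding, not a comment
# from the original source): each entry combines a block type with r (repeats),
# k (kernel size), s (stride), e (expansion ratio), c (output channels), plus optional
# 'se<ratio>' and 'noskip' flags. E.g. 'ir_r3_k5_s2_e3_c40_se0.25' = inverted residual x3,
# 5x5 depthwise, stride 2, expansion 3, 40 output channels, SE ratio 0.25; 'ds'/'dsa' are
# depthwise-separable blocks (the 'a' variant keeps the pointwise activation), 'cn' is a
# plain conv-bn-act, and 'noskip' disables the residual connection.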
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -793,8 +777,7 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -825,7 +808,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
[ ' ir_r1_k3_s1_e6_c320_noskip ' ]
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -833,8 +815,7 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -858,7 +839,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
[ ' ir_r3_k5_s2_e6_c88_se0.25 ' ] ,
[ ' ir_r1_k3_s1_e6_c144 ' ]
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -866,8 +846,7 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -885,7 +864,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
[ ' dsa_r6_k3_s2_c512 ' ] ,
[ ' dsa_r2_k3_s2_c1024 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -894,8 +872,7 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
act_fn = F . relu6 ,
head_conv = ' none ' ,
**kwargs
@@ -917,7 +894,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
[ ' ir_r3_k3_s2_e6_c160 ' ] ,
[ ' ir_r1_k3_s1_e6_c320 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -925,8 +901,7 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
act_fn = F . relu6 ,
**kwargs
)
@@ -958,7 +933,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
[ ' cn_r1_k1_s1_c960 ' ] , # hard-swish
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -966,8 +940,7 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
act_fn = hard_swish ,
se_gate_fn = hard_sigmoid ,
se_reduce_mid = True ,
@@ -994,7 +967,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
[ ' ir_r4_k3_s2_e7_c152 ' ] ,
[ ' ir_r1_k3_s1_e10_c104 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -1003,8 +975,7 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -1027,7 +998,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
[ ' ir_r6_k3_s2_e2_c152 ' ] ,
[ ' ir_r1_k3_s1_e6_c112 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -1036,8 +1006,7 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -1061,7 +1030,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
[ ' ir_r4_k5_s2_e6_c184 ' ] ,
[ ' ir_r1_k3_s1_e6_c352 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -1070,8 +1038,7 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -1101,7 +1068,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
[ ' ir_r1_k3_s1_e6_c320_noskip ' ]
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
model = GenEfficientNet (
_decode_arch_def ( arch_def ) ,
num_classes = num_classes ,
@@ -1109,8 +1075,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier = channel_multiplier ,
channel_divisor = 8 ,
channel_min = None ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
**kwargs
)
return model
@@ -1147,7 +1112,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
[ ' ir_r4_k5_s2_e6_c192_se0.25 ' ] ,
[ ' ir_r1_k3_s1_e6_c320_se0.25 ' ] ,
]
bn_momentum , bn_eps = _resolve_bn_params ( kwargs )
# NOTE: other models in the family didn't scale the feature count
num_features = _round_channels ( 1280 , channel_multiplier , 8 , None )
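# Worked example (illustrative, not part of the patch; assumes the usual round-to-divisor
# behaviour of _round_channels): with channel_multiplier=1.1, 1280 * 1.1 = 1408, already a
# multiple of 8, so num_features would come out as 1408; with the default 1.0 it stays 1280.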
model = GenEfficientNet (
@@ -1158,8 +1122,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
channel_divisor = 8 ,
channel_min = None ,
num_features = num_features ,
bn_momentum = bn_momentum ,
bn_eps = bn_eps ,
bn_args = _resolve_bn_args ( kwargs ) ,
act_fn = swish ,
**kwargs
)
@@ -1205,20 +1168,6 @@ def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return mnasnet_100 ( pretrained , num_classes , in_chans , * * kwargs )
@register_model
def tflite_mnasnet_100 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" MNASNet B1, depth multiplier of 1.0. """
default_cfg = default_cfgs [ ' tflite_mnasnet_100 ' ]
# these two args are for compat with tflite pretrained weights
kwargs [ ' folded_bn ' ] = True
kwargs [ ' padding_same ' ] = True
model = _gen_mnasnet_b1 ( 1.0 , num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def mnasnet_140 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" MNASNet B1, depth multiplier of 1.4 """
@@ -1269,20 +1218,6 @@ def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return semnasnet_100 ( pretrained , num_classes , in_chans , * * kwargs )
@register_model
def tflite_semnasnet_100 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" MNASNet A1, depth multiplier of 1.0. """
default_cfg = default_cfgs [ ' tflite_semnasnet_100 ' ]
# these two args are for compat with tflite pretrained weights
kwargs [ ' folded_bn ' ] = True
kwargs [ ' padding_same ' ] = True
model = _gen_mnasnet_a1 ( 1.0 , num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def semnasnet_140 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """