@@ -1,6 +1,7 @@
 """ Generic MobileNet

 A generic MobileNet class with building blocks to support a variety of models:
+* EfficientNet (B0-B4 in code right now, work in progress, still verifying)
 * MNasNet B1, A1 (SE), Small
 * MobileNet V1, V2, and V3 (work in progress)
 * FBNet-C (TODO A & B)
@@ -30,7 +31,8 @@ _models = [
     'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
     'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
     'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
-    'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100']
+    'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
+    'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4']

 __all__ = ['GenMobileNet', 'genmobilenet_model_names'] + _models
@@ -67,6 +69,11 @@ default_cfgs = {
     'chamnetv2_100': _cfg(url=''),
     'fbnetc_100': _cfg(url='https://www.dropbox.com/s/0ku2tztuibrynld/fbnetc_100-f49a0c5f.pth?dl=1'),
     'spnasnet_100': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
+    'efficientnet_b0': _cfg(url=''),
+    'efficientnet_b1': _cfg(url='', input_size=(3, 240, 240)),
+    'efficientnet_b2': _cfg(url='', input_size=(3, 260, 260)),
+    'efficientnet_b3': _cfg(url='', input_size=(3, 300, 300)),
+    'efficientnet_b4': _cfg(url='', input_size=(3, 380, 380)),
 }

 _DEBUG = False
@@ -101,23 +108,23 @@ def _resolve_bn_params(kwargs):
     return bn_momentum, bn_eps


-def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None):
+def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
     """ Round number of filters based on depth multiplier. """
-    if not depth_multiplier:
+    if not multiplier:
         return channels

-    channels *= depth_multiplier
-    min_depth = min_depth or depth_divisor
+    channels *= multiplier
+    channel_min = channel_min or divisor
     new_channels = max(
-        int(channels + depth_divisor / 2) // depth_divisor * depth_divisor,
-        min_depth)
+        int(channels + divisor / 2) // divisor * divisor,
+        channel_min)
     # Make sure that round down does not go down by more than 10%.
     if new_channels < 0.9 * channels:
-        new_channels += depth_divisor
+        new_channels += divisor
     return new_channels


-def _decode_block_str(block_str):
+def _decode_block_str(block_str, depth_multiplier=1.0):
     """ Decode block definition string

     Gets a list of block arg (dicts) through a string notation of arguments.
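For reference, the rename in `_round_channels` is behavior-preserving: channels are scaled by `multiplier`, snapped to the nearest multiple of `divisor`, and bumped back up one step if rounding lost more than 10%. A standalone sketch of the arithmetic (the body is duplicated here only so the example runs; sample values are illustrative, not from this diff):

    def round_channels_demo(channels, multiplier=1.0, divisor=8, channel_min=None):
        if not multiplier:
            return channels
        channels *= multiplier
        channel_min = channel_min or divisor
        new_channels = max(int(channels + divisor / 2) // divisor * divisor, channel_min)
        if new_channels < 0.9 * channels:  # never round down by more than 10%
            new_channels += divisor
        return new_channels

    assert round_channels_demo(32) == 32          # multiplier 1.0 is a no-op
    assert round_channels_demo(32, 0.5) == 16     # mnasnet_050-style halving
    assert round_channels_demo(32, 1.1) == 32     # 35.2 snaps down to 32, within the 10% bound
    assert round_channels_demo(320, 1.1) == 352   # already a multiple of 8 after scaling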
@@ -207,6 +214,7 @@ def _decode_block_str(block_str):
             block_type=block_type,
             kernel_size=int(options['k']),
             out_chs=int(options['c']),
+            se_ratio=float(options['se']) if 'se' in options else None,
             stride=int(options['s']),
             act_fn=act_fn,
             noskip=block_type == 'dsa' or noskip,
@@ -223,7 +231,9 @@ def _decode_block_str(block_str):
     else:
         assert False, 'Unknown block type (%s)' % block_type

-    # return a list of block args expanded by num_repeat
+    # return a list of block args expanded by num_repeat and
+    # scaled by depth_multiplier
+    num_repeat = int(math.ceil(num_repeat * depth_multiplier))
     return [deepcopy(block_args) for _ in range(num_repeat)]
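The net effect of the new `depth_multiplier` argument: the per-stage repeat count (the `r` token in block strings such as 'ir_r2_k5_s2_e6_c40_se0.25', which by the conventions visible in the decoder hunks reads as inverted residual, repeat 2, kernel 5, stride 2, expansion 6, 40 output channels, SE ratio 0.25) is scaled with a ceiling. A small sketch using the EfficientNet depth coefficients that appear later in this diff:

    import math

    for name, dm in [('b0', 1.0), ('b1', 1.1), ('b2', 1.2), ('b3', 1.4), ('b4', 1.8)]:
        print(name, math.ceil(2 * dm))  # an r2 stage runs 2, 3, 3, 3, 4 blocks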
@@ -243,14 +253,14 @@ def _decode_arch_args(string_list):
     return block_args


-def _decode_arch_def(arch_def):
+def _decode_arch_def(arch_def, depth_multiplier=1.0):
     arch_args = []
     for stack_idx, block_strings in enumerate(arch_def):
         assert isinstance(block_strings, list)
         stack_args = []
         for block_str in block_strings:
             assert isinstance(block_str, str)
-            stack_args.extend(_decode_block_str(block_str))
+            stack_args.extend(_decode_block_str(block_str, depth_multiplier))
         arch_args.append(stack_args)
     return arch_args
@@ -265,13 +275,13 @@ class _BlockBuilder:
     """

-    def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
+    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  folded_bn=False, padding_same=False, verbose=False):
-        self.depth_multiplier = depth_multiplier
-        self.depth_divisor = depth_divisor
-        self.min_depth = min_depth
+        self.channel_multiplier = channel_multiplier
+        self.channel_divisor = channel_divisor
+        self.channel_min = channel_min
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
@@ -283,7 +293,7 @@ class _BlockBuilder:
         self.in_chs = None

     def _round_channels(self, chs):
-        return _round_channels(chs, self.depth_multiplier, self.depth_divisor, self.min_depth)
+        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

     def _make_block(self, ba):
         bt = ba.pop('block_type')
@@ -327,7 +337,7 @@ class _BlockBuilder:
             blocks.append(block)
         return nn.Sequential(*blocks)

-    def __call__(self, in_chs, arch_def):
+    def __call__(self, in_chs, block_args):
         """ Build the blocks
         Args:
             in_chs: Number of input-channels passed to first block
@@ -336,13 +346,12 @@ class _BlockBuilder:
         Return:
             List of block stacks (each stack wrapped in nn.Sequential)
         """
-        arch_args = _decode_arch_def(arch_def)  # convert and expand string defs to arg dicts
         if self.verbose:
-            print('Building model trunk with %d stacks (stages)...' % len(arch_args))
+            print('Building model trunk with %d stacks (stages)...' % len(block_args))
         self.in_chs = in_chs
         blocks = []
-        # outer list of arch_args defines the stacks ('stages' by some conventions)
-        for stack_idx, stack in enumerate(arch_args):
+        # outer list of block_args defines the stacks ('stages' by some conventions)
+        for stack_idx, stack in enumerate(block_args):
             if self.verbose:
                 print('stack', stack_idx)
             assert isinstance(stack, list)
@@ -381,6 +390,10 @@ def _initialize_weight_default(m):
         nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')


+def swish(x):
+    return x * torch.sigmoid(x)
+
+
 def hard_swish(x):
     return x * F.relu6(x + 3.) / 6.
@@ -389,6 +402,46 @@ def hard_sigmoid(x):
     return F.relu6(x + 3.) / 6.


+class ChannelShuffle(nn.Module):
+    # FIXME haven't used yet
+    def __init__(self, groups):
+        super(ChannelShuffle, self).__init__()
+        self.groups = groups
+
+    def forward(self, x):
+        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
+        N, C, H, W = x.size()
+        g = self.groups
+        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
+            g, C
+        )
+        return (
+            x.view(N, g, int(C / g), H, W)
+            .permute(0, 2, 1, 3, 4)
+            .contiguous()
+            .view(N, C, H, W)
+        )
+
+
+class SqueezeExcite(nn.Module):
+    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
+        super(SqueezeExcite, self).__init__()
+        self.act_fn = act_fn
+        self.gate_fn = gate_fn
+        reduced_chs = reduce_chs or in_chs
+        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+
+    def forward(self, x):
+        # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
+        x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+        x_se = self.conv_reduce(x_se)
+        x_se = self.act_fn(x_se)
+        x_se = self.conv_expand(x_se)
+        x = self.gate_fn(x_se) * x
+        return x
+
+
 class ConvBnAct(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu,
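A quick shape check for the two relocated modules, assuming `ChannelShuffle` and `SqueezeExcite` are in scope as defined above (tensor sizes are arbitrary):

    import torch

    x = torch.randn(2, 32, 56, 56)

    shuffle = ChannelShuffle(groups=4)    # regroups channels; shape is preserved
    assert shuffle(x).shape == x.shape

    se = SqueezeExcite(32, reduce_chs=8)  # squeeze 32 -> 8 -> 32, then gate the input
    assert se(x).shape == x.shape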
@@ -413,12 +466,18 @@ class ConvBnAct(nn.Module):


 class DepthwiseSeparableConv(nn.Module):
     """ DepthwiseSeparable block
     Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
     factor of 1.0. This is an alternative to having an IR with an optional first pw conv.
     """
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
+                 se_ratio=0., se_gate_fn=torch.sigmoid,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  folded_bn=False, padding_same=False):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
+        self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.act_fn = act_fn
@@ -429,6 +488,12 @@ class DepthwiseSeparableConv(nn.Module):
             in_chs, in_chs, kernel_size,
             stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
         self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)

+        # Squeeze-and-excitation
+        if self.has_se:
+            self.se = SqueezeExcite(
+                in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
+
         self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
         self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
@@ -440,6 +505,9 @@ class DepthwiseSeparableConv(nn.Module):
             x = self.bn1(x)
         x = self.act_fn(x)

+        if self.has_se:
+            x = self.se(x)
+
         x = self.conv_pw(x)
         if self.bn2 is not None:
             x = self.bn2(x)
@@ -447,7 +515,7 @@ class DepthwiseSeparableConv(nn.Module):
             x = self.act_fn(x)

         if self.has_residual:
-            x += residual
+            x += residual  # FIXME add drop-connect
         return x
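The FIXME refers to drop-connect (per-sample stochastic depth) on the residual path, as used in the TF EfficientNet reference implementation. It is not implemented in this diff; a minimal sketch of one way it could look, with the rate and the wiring into forward() as assumptions:

    import torch

    def drop_connect(x, training=False, drop_rate=0.2):
        # One Bernoulli keep/drop draw per sample; survivors are rescaled by 1/keep_prob.
        if not training or drop_rate == 0.:
            return x
        keep_prob = 1. - drop_rate
        mask = torch.empty((x.size(0), 1, 1, 1), dtype=x.dtype, device=x.device)
        mask.bernoulli_(keep_prob)
        return x / keep_prob * mask

    # e.g. in forward(): x = drop_connect(x, self.training, drop_rate); x += residual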
@@ -481,46 +549,6 @@ class CascadeConv(nn.Sequential):
         return x


-class ChannelShuffle(nn.Module):
-    # FIXME haven't used yet
-    def __init__(self, groups):
-        super(ChannelShuffle, self).__init__()
-        self.groups = groups
-
-    def forward(self, x):
-        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
-        N, C, H, W = x.size()
-        g = self.groups
-        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
-            g, C
-        )
-        return (
-            x.view(N, g, int(C / g), H, W)
-            .permute(0, 2, 1, 3, 4)
-            .contiguous()
-            .view(N, C, H, W)
-        )
-
-
-class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
-        super(SqueezeExcite, self).__init__()
-        self.act_fn = act_fn
-        self.gate_fn = gate_fn
-        reduced_chs = reduce_chs or in_chs
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-
-    def forward(self, x):
-        # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
-        x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
-        x_se = self.conv_reduce(x_se)
-        x_se = self.act_fn(x_se)
-        x_se = self.conv_expand(x_se)
-        x = self.gate_fn(x_se) * x
-        return x
-
-
 class InvertedResidual(nn.Module):
     """ Inverted residual block w/ optional SE """
@@ -554,8 +582,8 @@ class InvertedResidual(nn.Module):
         # Squeeze-and-excitation
         if self.has_se:
             se_base_chs = mid_chs if se_reduce_mid else in_chs
-            self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
-                                    act_fn=act_fn, gate_fn=se_gate_fn)
+            self.se = SqueezeExcite(
+                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

         # Point-wise linear projection
         self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
@@ -591,7 +619,7 @@ class InvertedResidual(nn.Module):
             x = self.bn3(x)

         if self.has_residual:
-            x += residual
+            x += residual  # FIXME add drop-connect
         # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
@@ -609,22 +637,23 @@ class GenMobileNet(nn.Module):
       * FBNet A, B, and C
       * ChamNet (arch details are murky)
       * Single-Path NAS Pixel1
+      * EfficientNet
     """

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
-                 depth_multiplier=1.0, depth_divisor=8, min_depth=None,
+                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  global_pool='avg', head_conv='default', weight_init='goog',
                  folded_bn=False, padding_same=False):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
-        self.depth_multiplier = depth_multiplier
+        self.depth_multiplier = channel_multiplier
         self.drop_rate = drop_rate
         self.act_fn = act_fn
         self.num_features = num_features

-        stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
+        stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
         self.conv_stem = sconv2d(
             in_chans, stem_size, 3,
             padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
@@ -632,7 +661,7 @@ class GenMobileNet(nn.Module):
         in_chs = stem_size

         builder = _BlockBuilder(
-            depth_multiplier, depth_divisor, min_depth,
+            channel_multiplier, channel_divisor, channel_min,
             act_fn, se_gate_fn, se_reduce_mid,
             bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
@@ -705,14 +734,14 @@ class GenMobileNet(nn.Module):
         return self.classifier(x)


-def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
     """ Creates a mnasnet-a1 model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -732,12 +761,12 @@ def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -745,14 +774,14 @@ def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
     """ Creates a mnasnet-b1 model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -772,12 +801,12 @@ def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -785,14 +814,14 @@ def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
     """ Creates a mnasnet-small model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         ['ds_r1_k3_s1_c8'],
@@ -805,12 +834,12 @@ def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=8,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -818,7 +847,7 @@ def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate MobileNet-V1 network
     Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
     Paper: https://arxiv.org/abs/1704.04861
@@ -832,13 +861,13 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1024,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=F.relu6,
@@ -848,7 +877,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate MobileNet-V2 network
     Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
     Paper: https://arxiv.org/abs/1801.04381
@@ -864,12 +893,12 @@ def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=F.relu6,
@@ -878,14 +907,14 @@ def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
     """ Creates a MobileNet-V3 model.

     Ref impl: ?
     Paper: https://arxiv.org/abs/1905.02244

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -905,12 +934,12 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=hard_swish,
@@ -922,7 +951,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate Chameleon Network (ChamNet)

     Paper: https://arxiv.org/abs/1812.08934
@@ -941,13 +970,13 @@ def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1280,  # no idea what this is? try mobile/mnasnet default?
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -955,7 +984,7 @@ def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate Chameleon Network (ChamNet)

     Paper: https://arxiv.org/abs/1812.08934
@@ -974,13 +1003,13 @@ def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1280,  # no idea what this is? try mobile/mnasnet default?
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -988,7 +1017,7 @@ def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
    """ FBNet-C

     Paper: https://arxiv.org/abs/1812.03443
@@ -1008,13 +1037,13 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
         num_features=1984,  # paper suggests this, but is not 100% clear
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -1022,13 +1051,13 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
     """ Creates the Single-Path NAS model from the search targeted for the Pixel1 phone.

     Paper: https://arxiv.org/abs/1904.02877

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -1048,12 +1077,12 @@ def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -1061,6 +1090,41 @@ def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
     return model


+def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
+    """ Creates an EfficientNet model.
+
+    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    Paper: https://arxiv.org/abs/1905.11946
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25'],
+        ['ir_r4_k5_s2_e6_c192_se0.25'],
+        ['ir_r1_k3_s1_e6_c320_se0.25'],
+    ]
+    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
+    model = GenMobileNet(
+        _decode_arch_def(arch_def, depth_multiplier),
+        num_classes=num_classes,
+        stem_size=32,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
+        bn_momentum=bn_momentum,
+        bn_eps=bn_eps,
+        act_fn=swish,
+        **kwargs
+    )
+    return model
+
+
 def mnasnet_050(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
     """ MNASNet B1, depth multiplier of 0.5. """
     default_cfg = default_cfgs['mnasnet_050']
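Tying `_gen_efficientnet` to the compound-scaling table in the final hunk below: `depth_multiplier` stretches per-stage repeats via `_decode_arch_def`, while `channel_multiplier` widens outputs via `_round_channels` inside the builder. A back-of-envelope check for the 'ir_r3_k3_s2_e6_c80_se0.25' stage under the B2 coefficients (1.1, 1.2):

    import math

    channel_multiplier, depth_multiplier = 1.1, 1.2       # EfficientNet-B2
    print(math.ceil(3 * depth_multiplier))                # repeats: 3 -> 4
    print(int(80 * channel_multiplier + 8 / 2) // 8 * 8)  # out_chs: 80 -> 88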
@@ -1270,5 +1334,81 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model


+# EfficientNet params
+# (width_coefficient, depth_coefficient, resolution, dropout_rate)
+# 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+# 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+# 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+# 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+# 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+# 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+# 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+# 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+
+
+def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B0 """
+    default_cfg = default_cfgs['efficientnet_b0']
+    # NOTE dropout should be 0.2 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.0,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B1 """
+    default_cfg = default_cfgs['efficientnet_b1']
+    # NOTE dropout should be 0.2 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.1,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B2 """
+    default_cfg = default_cfgs['efficientnet_b2']
+    # NOTE dropout should be 0.3 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.1, depth_multiplier=1.2,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B3 """
+    default_cfg = default_cfgs['efficientnet_b3']
+    # NOTE dropout should be 0.3 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.2, depth_multiplier=1.4,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B4 """
+    default_cfg = default_cfgs['efficientnet_b4']
+    # NOTE dropout should be 0.4 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.4, depth_multiplier=1.8,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 def genmobilenet_model_names():
     return set(_models)
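Assumed usage of the new entrypoints (no pretrained weights are published yet, per the empty URLs in `default_cfgs` above):

    import torch

    model = efficientnet_b2(num_classes=1000, in_chans=3, pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 260, 260))  # B2 default input_size is 260x260
    print(out.shape)  # torch.Size([1, 1000])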