@@ -1,13 +1,13 @@
 """ Generic EfficientNets
 A generic class with building blocks to support a variety of models with efficient architectures:
-* EfficientNet (B0-B4 in code right now, work in progress, still verifying)
-* MNasNet B1, A1 (SE), Small
-* MobileNet V1, V2, and V3 (work in progress)
+* EfficientNet (B0-B5)
+* MixNet (Small, Medium, and Large)
+* MnasNet B1, A1 (SE), Small
+* MobileNet V1, V2, and V3
 * FBNet-C (TODO A & B)
 * ChamNet (TODO still guessing at architecture definition)
 * Single-Path NAS Pixel1
-* ShuffleNetV2 (TODO add IR shuffle block)
 * And likely more...
 
 TODO not all combinations and variations have been tested. Currently working on training hyper-params...
@@ -27,7 +27,7 @@ import torch.nn.functional as F
 from .registry import register_model
 from .helpers import load_pretrained
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
-from .conv2d_same import sconv2d
+from .conv2d_helpers import select_conv2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -37,7 +37,7 @@ __all__ = ['GenEfficientNet']
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'conv_stem', 'classifier': 'classifier',
         **kwargs
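A note on this default flip: since `**kwargs` is unpacked last in the dict literal, per-model overrides still win over the new `bicubic` default. A quick sketch (my own example; the URL is a placeholder, not a released checkpoint):

```python
cfg = _cfg(
    url='https://example.com/weights.pth',   # placeholder, not a real weight file
    input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882)
assert cfg['interpolation'] == 'bicubic'     # new default from this change
assert cfg['input_size'] == (3, 240, 240)    # kwarg override still wins
```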
@@ -48,14 +48,12 @@ default_cfgs = {
     'mnasnet_050': _cfg(url=''),
     'mnasnet_075': _cfg(url=''),
     'mnasnet_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
-        interpolation='bicubic'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
     'mnasnet_140': _cfg(url=''),
     'semnasnet_050': _cfg(url=''),
     'semnasnet_075': _cfg(url=''),
     'semnasnet_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
-        interpolation='bicubic'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
     'semnasnet_140': _cfg(url=''),
     'mnasnet_small': _cfg(url=''),
     'mobilenetv1_100': _cfg(url=''),
@@ -63,23 +61,23 @@ default_cfgs = {
     'mobilenetv3_050': _cfg(url=''),
     'mobilenetv3_075': _cfg(url=''),
     'mobilenetv3_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
-        interpolation='bicubic'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'),
     'chamnetv1_100': _cfg(url=''),
     'chamnetv2_100': _cfg(url=''),
     'fbnetc_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
+        interpolation='bilinear'),
     'spnasnet_100': _cfg(
-        url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
+        url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1',
+        interpolation='bilinear'),
     'efficientnet_b0': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth',
-        interpolation='bicubic'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
     'efficientnet_b1': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
-        input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
+        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
     'efficientnet_b2': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
-        input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
+        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
     'efficientnet_b3': _cfg(
         url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
     'efficientnet_b4': _cfg(
@@ -88,22 +86,31 @@ default_cfgs = {
         url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
     'tf_efficientnet_b0': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
-        input_size=(3, 224, 224), interpolation='bicubic'),
+        input_size=(3, 224, 224)),
     'tf_efficientnet_b1': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
-        input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
+        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
     'tf_efficientnet_b2': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
-        input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
+        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
     'tf_efficientnet_b3': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
-        input_size=(3, 300, 300), pool_size=(10, 10), interpolation='bicubic', crop_pct=0.904),
+        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
     'tf_efficientnet_b4': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
-        input_size=(3, 380, 380), pool_size=(12, 12), interpolation='bicubic', crop_pct=0.922),
+        input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
     'tf_efficientnet_b5': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
-        input_size=(3, 456, 456), pool_size=(15, 15), interpolation='bicubic', crop_pct=0.934)
+        input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+    'mixnet_s': _cfg(url=''),
+    'mixnet_m': _cfg(url=''),
+    'mixnet_l': _cfg(url=''),
+    'tf_mixnet_s': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
+    'tf_mixnet_m': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'),
+    'tf_mixnet_l': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
 }
@@ -151,6 +158,13 @@ def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
     return new_channels
 
 
+def _parse_ksize(ss):
+    if ss.isdigit():
+        return int(ss)
+    else:
+        return [int(k) for k in ss.split('.')]
+
+
 def _decode_block_str(block_str, depth_multiplier=1.0):
     """ Decode block definition string
@@ -168,7 +182,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     e - expansion ratio,
     c - output channels,
     se - squeeze/excitation ratio
-    a - activation fn ('re', 'r6', or 'hs')
+    n - activation fn ('re', 'r6', 'hs', or 'sw')
     Args:
         block_str: a string representation of block arguments.
     Returns:
@@ -184,7 +198,9 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     noskip = False
     for op in ops:
         # string options being checked on individual basis, combine if they grow
-        if op.startswith('a'):
+        if op == 'noskip':
+            noskip = True
+        elif op.startswith('n'):
             # activation fn
             key = op[0]
             v = op[1:]
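Worth noting why the `noskip` check moved ahead of the prefix test: with the activation key renamed from 'a' to 'n', the literal token 'noskip' would otherwise be consumed by the activation branch. A trivial check of the pitfall (my own example):

```python
# 'noskip' also starts with 'n', so the equality test must run first.
assert 'noskip'.startswith('n')
```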
@@ -194,11 +210,11 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
                 value = F.relu6
             elif v == 'hs':
                 value = hard_swish
+            elif v == 'sw':
+                value = swish
             else:
                 continue
             options[key] = value
-        elif op == 'noskip':
-            noskip = True
         else:
             # all numeric options
             splits = re.split(r'(\d.*)', op)
@@ -207,14 +223,18 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
             options[key] = value
 
     # if act_fn is None, the model default (passed to model init) will be used
-    act_fn = options['a'] if 'a' in options else None
+    act_fn = options['n'] if 'n' in options else None
+    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
 
     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
     if block_type == 'ir':
         block_args = dict(
             block_type=block_type,
-            kernel_size=int(options['k']),
+            dw_kernel_size=_parse_ksize(options['k']),
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             exp_ratio=float(options['e']),
             se_ratio=float(options['se']) if 'se' in options else None,
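Putting the new options together, here is a worked decode of a MixNet-style block string (my own example; it assumes the 'ir' branch also carries the usual `stride` option, which this hunk does not show):

```python
ba = _decode_block_str('ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c40_se0.25_nsw')[0]
assert ba['dw_kernel_size'] == [3, 5, 7]   # mixed 3x3/5x5/7x7 depthwise
assert ba['exp_kernel_size'] == [1, 1]     # 'a1.1': two-group 1x1 expansion
assert ba['pw_kernel_size'] == [1, 1]      # 'p1.1': two-group 1x1 projection
assert ba['exp_ratio'] == 6.0 and ba['se_ratio'] == 0.25  # swish via 'nsw'
```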
@@ -222,20 +242,17 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
             act_fn=act_fn,
             noskip=noskip,
         )
-        if 'g' in options:
-            block_args['pw_group'] = options['g']
-            if options['g'] > 1:
-                block_args['shuffle_type'] = 'mid'
     elif block_type == 'ds' or block_type == 'dsa':
         block_args = dict(
             block_type=block_type,
-            kernel_size=int(options['k']),
+            dw_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             se_ratio=float(options['se']) if 'se' in options else None,
             stride=int(options['s']),
             act_fn=act_fn,
-            noskip=block_type == 'dsa' or noskip,
             pw_act=block_type == 'dsa',
+            noskip=block_type == 'dsa' or noskip,
         )
     elif block_type == 'cn':
         block_args = dict(
@@ -254,15 +271,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     return [deepcopy(block_args) for _ in range(num_repeat)]
 
 
-def _get_padding(kernel_size, stride, dilation=1):
-    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
-    return padding
-
-
-def _padding_arg(default, padding_same=False):
-    return 'SAME' if padding_same else default
-
-
 def _decode_arch_args(string_list):
     block_args = []
     for block_str in string_list:
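For context on why `_get_padding`/`_padding_arg` could be dropped: static symmetric padding cannot always reproduce TF 'SAME' behavior, which depends on the runtime input size. A rough sketch of the math the 'same' path has to do (my own illustration, not the `conv2d_helpers` source):

```python
import math

def tf_same_pad_amount(input_size, kernel_size, stride, dilation=1):
    # Total padding TF 'SAME' applies along one spatial dimension.
    out = math.ceil(input_size / stride)
    return max((out - 1) * stride + (kernel_size - 1) * dilation + 1 - input_size, 0)

# k=3, s=2 on a 224 input needs 1 total pad pixel (asymmetric: 0 left, 1 right),
# which a fixed nn.Conv2d padding argument cannot express.
assert tf_same_pad_amount(224, kernel_size=3, stride=2) == 1
```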
@@ -316,20 +324,18 @@ class _BlockBuilder:
     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
 
     """
     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
-                 bn_args=_BN_ARGS_PT, padding_same=False,
-                 verbose=False):
+                 pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
+                 bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
         self.channel_min = channel_min
-        self.drop_connect_rate = drop_connect_rate
+        self.pad_type = pad_type
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
         self.bn_args = bn_args
-        self.padding_same = padding_same
+        self.drop_connect_rate = drop_connect_rate
         self.verbose = verbose
 
         # updated during build
@@ -345,7 +351,7 @@ class _BlockBuilder:
         ba['in_chs'] = self.in_chs
         ba['out_chs'] = self._round_channels(ba['out_chs'])
         ba['bn_args'] = self.bn_args
-        ba['padding_same'] = self.padding_same
+        ba['pad_type'] = self.pad_type
         # block act fn overrides the model default
         ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
         assert ba['act_fn'] is not None
@@ -493,16 +499,11 @@ class SqueezeExcite(nn.Module):
 class ConvBnAct(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_fn=F.relu,
-                 bn_args=_BN_ARGS_PT, padding_same=False):
+                 stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT):
         super(ConvBnAct, self).__init__()
         assert stride in [1, 2]
         self.act_fn = act_fn
-        padding = _padding_arg(_get_padding(kernel_size, stride), padding_same)
-        self.conv = sconv2d(
-            in_chs, out_chs, kernel_size,
-            stride=stride, padding=padding, bias=False)
+        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
         self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
@@ -517,10 +518,11 @@ class DepthwiseSeparableConv(nn.Module):
     Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
     factor of 1.0. This is an alternative to having a IR with optional first pw conv.
     """
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_fn=F.relu, noskip=False, pw_act=False,
+    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
+                 stride=1, pad_type='', act_fn=F.relu, noskip=False,
+                 pw_kernel_size=1, pw_act=False,
                  se_ratio=0., se_gate_fn=sigmoid,
-                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
+                 bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_se = se_ratio is not None and se_ratio > 0.
@@ -528,12 +530,9 @@ class DepthwiseSeparableConv(nn.Module):
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.act_fn = act_fn
         self.drop_connect_rate = drop_connect_rate
-        dw_padding = _padding_arg(kernel_size // 2, padding_same)
-        pw_padding = _padding_arg(0, padding_same)
 
-        self.conv_dw = sconv2d(
-            in_chs, in_chs, kernel_size,
-            stride=stride, padding=dw_padding, groups=in_chs, bias=False)
+        self.conv_dw = select_conv2d(
+            in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
         self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)
 
         # Squeeze-and-excitation
@@ -541,7 +540,7 @@ class DepthwiseSeparableConv(nn.Module):
             self.se = SqueezeExcite(
                 in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
 
-        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
+        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
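All conv call sites now funnel through `select_conv2d` from `conv2d_helpers`, which this diff does not show. A hypothetical sketch of the dispatch it implies; `MixedConv2d` and `conv2d_pad` are assumed names, not confirmed here:

```python
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    # Hypothetical: a kernel-size list selects a mixed conv, one group per size.
    if isinstance(kernel_size, list):
        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
    # Otherwise a plain conv; depthwise=True stands in for the old groups=in_chs.
    depthwise = kwargs.pop('depthwise', False)
    groups = in_chs if depthwise else 1
    # conv2d_pad (assumed) resolves padding='' (static) vs 'same' (dynamic TF-style).
    return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
```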
@@ -569,31 +568,29 @@ class DepthwiseSeparableConv(nn.Module):
 class InvertedResidual(nn.Module):
     """ Inverted residual block w/ optional SE """
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
+    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
+                 stride=1, pad_type='', act_fn=F.relu, noskip=False,
+                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                  se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
-                 shuffle_type=None, pw_group=1,
-                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
+                 shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.act_fn = act_fn
         self.drop_connect_rate = drop_connect_rate
-        dw_padding = _padding_arg(kernel_size // 2, padding_same)
-        pw_padding = _padding_arg(0, padding_same)
 
         # Point-wise expansion
-        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
         self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
 
         self.shuffle_type = shuffle_type
-        if shuffle_type is not None:
-            self.shuffle = ChannelShuffle(pw_group)
+        if shuffle_type is not None and isinstance(exp_kernel_size, list):
+            self.shuffle = ChannelShuffle(len(exp_kernel_size))
 
         # Depth-wise convolution
-        self.conv_dw = sconv2d(
-            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
+        self.conv_dw = select_conv2d(
+            mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
         self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)
 
         # Squeeze-and-excitation
@@ -603,7 +600,7 @@ class InvertedResidual(nn.Module):
                 mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
 
         # Point-wise linear projection
-        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
@@ -649,18 +646,19 @@ class GenEfficientNet(nn.Module):
       * MobileNet-V1
      * MobileNet-V2
       * MobileNet-V3
-      * MNASNet A1, B1, and small
+      * MnasNet A1, B1, and small
       * FBNet A, B, and C
       * ChamNet (arch details are murky)
       * Single-Path NAS Pixel1
-      * EfficientNetB0-B4 (rest easy to add)
+      * EfficientNet B0-B5
+      * MixNet S, M, L
     """
 
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
+                 pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0.,
                  se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
-                 global_pool='avg', head_conv='default', weight_init='goog', padding_same=False):
+                 global_pool='avg', head_conv='default', weight_init='goog'):
         super(GenEfficientNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -668,16 +666,14 @@ class GenEfficientNet(nn.Module):
         self.num_features = num_features
 
         stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = sconv2d(
-            in_chans, stem_size, 3,
-            padding=_padding_arg(1, padding_same), stride=2, bias=False)
+        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
         in_chs = stem_size
 
         builder = _BlockBuilder(
             channel_multiplier, channel_divisor, channel_min,
-            drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
-            bn_args, padding_same, verbose=_DEBUG)
+            pad_type, act_fn, se_gate_fn, se_reduce_mid,
+            bn_args, drop_connect_rate, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
@@ -687,9 +683,7 @@ class GenEfficientNet(nn.Module):
             assert in_chs == self.num_features
         else:
             self.efficient_head = head_conv == 'efficient'
-            self.conv_head = sconv2d(
-                in_chs, self.num_features, 1,
-                padding=_padding_arg(0, padding_same), bias=False)
+            self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type)
             self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -919,11 +913,11 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
     """
     arch_def = [
         # stage 0, 112x112 in
-        ['ds_r1_k3_s1_e1_c16_are_noskip'],  # relu
+        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu
         # stage 1, 112x112 in
-        ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'],  # relu
+        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
         # stage 2, 56x56 in
-        ['ir_r3_k5_s2_e3_c40_se0.25_are'],  # relu
+        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
         # stage 3, 28x28 in
         ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
         # stage 4, 14x14in
@@ -1129,6 +1123,78 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
     return model
 
 
+def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
+    """Creates a MixNet Small model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+    Paper: https://arxiv.org/abs/1907.09595
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_e1_c16'],  # relu
+        # stage 1, 112x112 in
+        ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'],  # relu
+        # stage 2, 56x56 in
+        ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
+        # stage 3, 28x28 in
+        ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
+        # stage 4, 14x14in
+        ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
+        # stage 5, 14x14in
+        ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
+        # 7x7
+    ]
+    model = GenEfficientNet(
+        _decode_arch_def(arch_def),
+        num_classes=num_classes,
+        stem_size=16,
+        num_features=1536,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
+        bn_args=_resolve_bn_args(kwargs),
+        act_fn=F.relu,
+        **kwargs
+    )
+    return model
+
+
+def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
+    """Creates a MixNet Medium-Large model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+    Paper: https://arxiv.org/abs/1907.09595
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_e1_c24'],  # relu
+        # stage 1, 112x112 in
+        ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'],  # relu
+        # stage 2, 56x56 in
+        ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
+        # stage 3, 28x28 in
+        ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
+        # stage 4, 14x14in
+        ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
+        # stage 5, 14x14in
+        ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
+        # 7x7
+    ]
+    model = GenEfficientNet(
+        _decode_arch_def(arch_def),
+        num_classes=num_classes,
+        stem_size=24,
+        num_features=1536,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
+        bn_args=_resolve_bn_args(kwargs),
+        act_fn=F.relu,
+        **kwargs
+    )
+    return model
+
+
 @register_model
 def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ MNASNet B1, depth multiplier of 0.5. """
@@ -1440,7 +1506,7 @@ def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B0. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b0']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.0,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1455,7 +1521,7 @@ def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B1. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b1']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.1,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1470,7 +1536,7 @@ def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B2. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b2']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.1, depth_multiplier=1.2,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1485,7 +1551,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B3. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b3']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.2, depth_multiplier=1.4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1500,7 +1566,7 @@ def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B4. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b4']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.4, depth_multiplier=1.8,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1515,7 +1581,7 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """ EfficientNet-B5. Tensorflow compatible variant """
     default_cfg = default_cfgs['tf_efficientnet_b5']
     kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
-    kwargs['padding_same'] = True
+    kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
         channel_multiplier=1.6, depth_multiplier=2.2,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1525,5 +1591,89 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     return model
 
 
+@register_model
+def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Small model.
+    """
+    default_cfg = default_cfgs['mixnet_s']
+    model = _gen_mixnet_s(
+        channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    #if pretrained:
+    #    load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Medium model.
+    """
+    default_cfg = default_cfgs['mixnet_m']
+    model = _gen_mixnet_m(
+        channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    #if pretrained:
+    #    load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Large model.
+    """
+    default_cfg = default_cfgs['mixnet_l']
+    model = _gen_mixnet_m(
+        channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Small model. Tensorflow compatible variant
+    """
+    default_cfg = default_cfgs['tf_mixnet_s']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mixnet_s(
+        channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tf_mixnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Medium model. Tensorflow compatible variant
+    """
+    default_cfg = default_cfgs['tf_mixnet_m']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mixnet_m(
+        channel_multiplier=1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tf_mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Large model. Tensorflow compatible variant
+    """
+    default_cfg = default_cfgs['tf_mixnet_l']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mixnet_m(
+        channel_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 def gen_efficientnet_model_names():
     return set(_models)