Merge pull request #105 from rwightman/efficientnet-lite

EfficientNet-Lite
pull/115/head
Ross Wightman 4 years ago committed by GitHub
commit 71b5cd67da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,9 @@
## What's New
### March 18, 2020
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
### Feb 29, 2020
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
@ -252,7 +255,9 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 |
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 |
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 |
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 |
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 |
| gluon_senet154 | 81.224 (18.776) | 95.356 (4.644) | 115.09 | bicubic | 224 |
| gluon_resnet152_v1s | 81.012 (18.988) | 95.416 (4.584) | 60.32 | bicubic | 224 |
@ -271,6 +276,8 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 |
| gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | 224 |
| gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | 224 |
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 |
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 |
| gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | 224 |
| gluon_xception65 | 79.604 (20.396) | 94.748 (5.252) | 39.92 | bicubic | 299 |
| gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | 224 |
@ -299,6 +306,8 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 |
| gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | 224 |
| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | 299 |
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 |
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 |
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 |
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 |
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 |
@ -307,10 +316,14 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 |
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 |
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 |
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 |
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 |
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 |
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 |
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 |
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 |
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 |
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 |
| gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | 224 |
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 |

@ -191,12 +191,25 @@ model_list = [
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_cc_b1_8e', 'EfficientNet-CondConv-B1 8 experts', '1904.04971',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_em', 'EfficientNet-EdgeTPU-M', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_el', 'EfficientNet-EdgeTPU-L', '1905.11946', batch_size=BATCH_SIZE//2,
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_lite0', 'EfficientNet-Lite0', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_lite1', 'EfficientNet-Lite1', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_lite2', 'EfficientNet-Lite2', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_lite3', 'EfficientNet-Lite3', '1905.11946', batch_size=BATCH_SIZE // 2,
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_lite4', 'EfficientNet-Lite4', '1905.11946', batch_size=BATCH_SIZE // 2,
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from official Tensorflow weights'),
_entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),

@ -52,12 +52,14 @@ default_cfgs = {
'mnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
'mnasnet_140': _cfg(url=''),
'semnasnet_050': _cfg(url=''),
'semnasnet_075': _cfg(url=''),
'semnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
'semnasnet_140': _cfg(url=''),
'mnasnet_small': _cfg(url=''),
'mobilenetv2_100': _cfg(url=''),
'fbnetc_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
@ -65,6 +67,7 @@ default_cfgs = {
'spnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
interpolation='bilinear'),
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
'efficientnet_b1': _cfg(
@ -94,15 +97,32 @@ default_cfgs = {
url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'efficientnet_l2': _cfg(
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
'efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
'efficientnet_em': _cfg(
url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_el': _cfg(
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'efficientnet_cc_b0_4e': _cfg(url=''),
'efficientnet_cc_b0_8e': _cfg(url=''),
'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_lite0': _cfg(
url=''),
'efficientnet_lite1': _cfg(
url='',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_lite2': _cfg(
url='',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'efficientnet_lite3': _cfg(
url='',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'efficientnet_lite4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)),
@ -130,6 +150,7 @@ default_cfgs = {
'tf_efficientnet_b8': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_b0_ap': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)),
@ -165,6 +186,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_b0_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
input_size=(3, 224, 224)),
@ -195,6 +217,7 @@ default_cfgs = {
'tf_efficientnet_l2_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
'tf_efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@ -207,6 +230,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_cc_b0_4e': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
@ -217,6 +241,33 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_lite0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
),
'tf_efficientnet_lite1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882,
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
),
'tf_efficientnet_lite2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890,
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
),
'tf_efficientnet_lite3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'),
'tf_efficientnet_lite4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
'mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
'mixnet_m': _cfg(
@ -226,6 +277,7 @@ default_cfgs = {
'mixnet_xl': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'),
'mixnet_xxl': _cfg(),
'tf_mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
'tf_mixnet_m': _cfg(
@ -253,7 +305,7 @@ class EfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -264,7 +316,8 @@ class EfficientNet(nn.Module):
self._in_chs = in_chans
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -333,7 +386,7 @@ class EfficientNetFeatures(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -346,7 +399,8 @@ class EfficientNetFeatures(nn.Module):
self._in_chs = in_chans
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -707,6 +761,47 @@ def _gen_efficientnet_condconv(
return model
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet-Lite model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r2_k5_s2_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['ir_r3_k5_s1_e6_c112'],
['ir_r4_k5_s2_e6_c192'],
['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
num_features=1280,
stem_size=32,
fix_stem=True,
channel_multiplier=channel_multiplier,
act_layer=nn.ReLU6,
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.
@ -1032,6 +1127,51 @@ def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
return model
@register_model
def efficientnet_lite0(pretrained=False, **kwargs):
""" EfficientNet-Lite0 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_lite(
'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_lite1(pretrained=False, **kwargs):
""" EfficientNet-Lite1 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_lite(
'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_lite2(pretrained=False, **kwargs):
""" EfficientNet-Lite2 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_lite(
'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_lite3(pretrained=False, **kwargs):
""" EfficientNet-Lite3 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_lite(
'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_lite4(pretrained=False, **kwargs):
""" EfficientNet-Lite4 """
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
model = _gen_efficientnet_lite(
'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """
@ -1386,6 +1526,61 @@ def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
return model
@register_model
def tf_efficientnet_lite0(pretrained=False, **kwargs):
""" EfficientNet-Lite0 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_lite1(pretrained=False, **kwargs):
""" EfficientNet-Lite1 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_lite2(pretrained=False, **kwargs):
""" EfficientNet-Lite2 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_lite3(pretrained=False, **kwargs):
""" EfficientNet-Lite3 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_lite4(pretrained=False, **kwargs):
""" EfficientNet-Lite4 """
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
@register_model
def mixnet_s(pretrained=False, **kwargs):
"""Creates a MixNet Small model.

@ -174,7 +174,7 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
return sa_scaled
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1):
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
@ -187,7 +187,10 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
ba['num_experts'] *= experts_multiplier
stack_args.append(ba)
repeats.append(rep)
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args

Loading…
Cancel
Save