From bd05258f7b3808c786ece91cd7622f58e51bcb70 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Mar 2020 23:31:45 -0700 Subject: [PATCH 1/3] EfficientNet-Lite model added w/ converted checkpoints, validation in progress... --- timm/models/efficientnet.py | 203 +++++++++++++++++++++++++++- timm/models/efficientnet_builder.py | 7 +- 2 files changed, 204 insertions(+), 6 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index fe275c0f..7439ea34 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -52,12 +52,14 @@ default_cfgs = { 'mnasnet_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), 'mnasnet_140': _cfg(url=''), + 'semnasnet_050': _cfg(url=''), 'semnasnet_075': _cfg(url=''), 'semnasnet_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), 'semnasnet_140': _cfg(url=''), 'mnasnet_small': _cfg(url=''), + 'mobilenetv2_100': _cfg(url=''), 'fbnetc_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', @@ -65,6 +67,7 @@ default_cfgs = { 'spnasnet_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', interpolation='bilinear'), + 'efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), 'efficientnet_b1': _cfg( @@ -94,15 +97,32 @@ default_cfgs = { url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), 'efficientnet_l2': _cfg( url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + 'efficientnet_es': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), 'efficientnet_em': _cfg( url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'efficientnet_el': _cfg( url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_cc_b0_4e': _cfg(url=''), 'efficientnet_cc_b0_8e': _cfg(url=''), 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0': _cfg( + url=''), + 'efficientnet_lite1': _cfg( + url='', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2': _cfg( + url='', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3': _cfg( + url='', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', input_size=(3, 224, 224)), @@ -130,6 +150,7 @@ default_cfgs = { 'tf_efficientnet_b8': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'tf_efficientnet_b0_ap': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), @@ -165,6 +186,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'tf_efficientnet_b0_ns': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', input_size=(3, 224, 224)), @@ -195,6 +217,7 @@ default_cfgs = { 'tf_efficientnet_l2_ns': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + 'tf_efficientnet_es': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), @@ -207,6 +230,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_cc_b0_4e': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), @@ -217,6 +241,33 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922, interpolation='bilinear'), + 'mixnet_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), 'mixnet_m': _cfg( @@ -226,6 +277,7 @@ default_cfgs = { 'mixnet_xl': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'), 'mixnet_xxl': _cfg(), + 'tf_mixnet_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), 'tf_mixnet_m': _cfg( @@ -253,7 +305,7 @@ class EfficientNet(nn.Module): def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., + output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): super(EfficientNet, self).__init__() norm_kwargs = norm_kwargs or {} @@ -264,7 +316,8 @@ class EfficientNet(nn.Module): self._in_chs = in_chans # Stem - stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -333,7 +386,7 @@ class EfficientNetFeatures(nn.Module): def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., + output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(EfficientNetFeatures, self).__init__() norm_kwargs = norm_kwargs or {} @@ -346,7 +399,8 @@ class EfficientNetFeatures(nn.Module): self._in_chs = in_chans # Stem - stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -707,6 +761,47 @@ def _gen_efficientnet_condconv( return model +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Small model. @@ -1032,6 +1127,51 @@ def efficientnet_cc_b1_8e(pretrained=False, **kwargs): return model +@register_model +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + @register_model def tf_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0. Tensorflow compatible variant """ @@ -1386,6 +1526,61 @@ def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): return model +@register_model +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + @register_model def mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index f8f0df8a..3876a2d1 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -174,7 +174,7 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c return sa_scaled -def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1): +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): arch_args = [] for stack_idx, block_strings in enumerate(arch_def): assert isinstance(block_strings, list) @@ -187,7 +187,10 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_ ba['num_experts'] *= experts_multiplier stack_args.append(ba) repeats.append(rep) - arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) return arch_args From 3406e582cf8a03d883a6a43e96ad43bd8852a69e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Mar 2020 13:12:30 -0700 Subject: [PATCH 2/3] Add EfficientNet-Lite results, update README --- README.md | 13 +++++++++++++ timm/models/efficientnet.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 62c93cce..56112e5c 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ ## What's New +### March 18, 2020 +* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + ### Feb 29, 2020 * New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1 * IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models @@ -252,7 +255,9 @@ For the models below, the model code and weight porting from Tensorflow or MXNet | tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | | tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | | tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | +| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | | tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | +| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | | tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | | gluon_senet154 | 81.224 (18.776) | 95.356 (4.644) | 115.09 | bicubic | 224 | | gluon_resnet152_v1s | 81.012 (18.988) | 95.416 (4.584) | 60.32 | bicubic | 224 | @@ -271,6 +276,8 @@ For the models below, the model code and weight porting from Tensorflow or MXNet | tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | | gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | 224 | | gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | 224 | +| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | +| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | | gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | 224 | | gluon_xception65 | 79.604 (20.396) | 94.748 (5.252) | 39.92 | bicubic | 299 | | gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | 224 | @@ -299,6 +306,8 @@ For the models below, the model code and weight porting from Tensorflow or MXNet | tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | | gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | 224 | | adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | 299 | +| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | +| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | | tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | | tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | | tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | @@ -307,10 +316,14 @@ For the models below, the model code and weight porting from Tensorflow or MXNet | tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | | tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | | tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | +| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | +| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | | tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | | tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | | tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | | tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | +| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | +| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | | gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | 224 | | tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 | | tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7439ea34..cddb750e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -266,7 +266,7 @@ default_cfgs = { 'tf_efficientnet_lite4': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922, interpolation='bilinear'), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), 'mixnet_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), From ef30e7f2b6fd1a1a9c4a27e21f4aa09d04ad1f54 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Mar 2020 13:15:54 -0700 Subject: [PATCH 3/3] Update sotabench with EffNet Lite models --- sotabench.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sotabench.py b/sotabench.py index 7b896819..de3828f4 100644 --- a/sotabench.py +++ b/sotabench.py @@ -191,12 +191,25 @@ model_list = [ model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_cc_b1_8e', 'EfficientNet-CondConv-B1 8 experts', '1904.04971', model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_em', 'EfficientNet-EdgeTPU-M', '1905.11946', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_efficientnet_el', 'EfficientNet-EdgeTPU-L', '1905.11946', batch_size=BATCH_SIZE//2, model_desc='Ported from official Google AI Tensorflow weights'), + + _entry('tf_efficientnet_lite0', 'EfficientNet-Lite0', '1905.11946', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_lite1', 'EfficientNet-Lite1', '1905.11946', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_lite2', 'EfficientNet-Lite2', '1905.11946', + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_lite3', 'EfficientNet-Lite3', '1905.11946', batch_size=BATCH_SIZE // 2, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_efficientnet_lite4', 'EfficientNet-Lite4', '1905.11946', batch_size=BATCH_SIZE // 2, + model_desc='Ported from official Google AI Tensorflow weights'), + _entry('tf_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from official Tensorflow weights'), _entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'), _entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),