Add ported EfficientNet-L2, B0-B7 NoisyStudent weights from TF TPU

pull/88/head
Ross Wightman 4 years ago
parent d0eb59ef46
commit ba15ca47e8

@ -2,6 +2,9 @@
## What's New
### Feb 12, 2020
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
### Feb 6, 2020
* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
@ -98,15 +101,16 @@ Included models:
* DPN (from [myself](https://github.com/rwightman/pytorch-dpn-pretrained))
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
* EfficientNet (from my standalone [GenEfficientNet](https://github.com/rwightman/gen-efficientnet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) -- TF weights ported
* EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946) -- TF weights ported, B0-B2 finetuned PyTorch
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) --TF weights ported
* MixNet (https://arxiv.org/abs/1907.09595) -- TF weights ported, PyTorch finetuned (S, M, L) or trained models (XL)
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) -- trained in PyTorch
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
* EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946)
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
* MixNet (https://arxiv.org/abs/1907.09595)
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V2 (https://arxiv.org/abs/1801.04381)
* FBNet-C (https://arxiv.org/abs/1812.03443) -- trained in PyTorch
* Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant
* MobileNet-V3 (https://arxiv.org/abs/1905.02244) -- pretrained PyTorch model, official TF weights ported
* FBNet-C (https://arxiv.org/abs/1812.03443)
* Single-Path NAS (https://arxiv.org/abs/1904.02877)
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
* HRNet
* code from https://github.com/HRNet/HRNet-Image-Classification, paper https://arxiv.org/abs/1908.07919
* SelecSLS
@ -178,30 +182,48 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
|---|---|---|---|---|---|
| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 |
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 |
| tf_efficientnet_l2_ns *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 |
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 475 |
| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 |
| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 |
| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 |
| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 |
| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 |
| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 |
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 |
| tf_efficientnet_b8 | 85.37 (14.63) | 97.39 (2.61) | 87.4 | bicubic | 672 |
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 |
| tf_efficientnet_b8 | 85.37 (14.63) | 97.39 (2.61) | 87.4 | bicubic | 672 |
| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 |
| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 |
| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 |
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 |
| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 |
| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 |
| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 |
| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 |
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 |
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 |
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 |
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 |
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 |
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 |
| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 |
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 |
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 |
| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 |
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 |
| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 |
| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 |
| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 |
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 |
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 |
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 |
| gluon_senet154 | 81.224 (18.776) | 95.356 (4.644) | 115.09 | bicubic | 224 |
| gluon_resnet152_v1s | 81.012 (18.988) | 95.416 (4.584) | 60.32 | bicubic | 224 |
| gluon_seresnext101_32x4d | 80.902 (19.098) | 95.294 (4.706) | 48.96 | bicubic | 224 |
@ -233,10 +255,12 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 |
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 |
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 |
| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 |
| gluon_inception_v3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | 299 |
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 |
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 |
| gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | 224 |
| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 |
| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 |
| gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | 224 |
| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 |

@ -2,10 +2,11 @@
An implementation of EfficienNet that covers variety of related models with efficient architectures:
* EfficientNet (B0-B8 + Tensorflow pretrained AutoAug/RandAug/AdvProp weight ports)
* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
- EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
- CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
- Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
- Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
* MixNet (Small, Medium, and Large)
- MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
@ -91,6 +92,8 @@ default_cfgs = {
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'efficientnet_b8': _cfg(
url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'efficientnet_l2': _cfg(
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
'efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
'efficientnet_em': _cfg(
@ -162,6 +165,36 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_b0_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
input_size=(3, 224, 224)),
'tf_efficientnet_b1_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_l2_ns_475': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
'tf_efficientnet_l2_ns': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
'tf_efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@ -208,7 +241,7 @@ class EfficientNet(nn.Module):
""" (Generic) EfficientNet
A flexible and performant PyTorch implementation of efficient network architectures, including:
* EfficientNet B0-B8
* EfficientNet B0-B8, L2
* EfficientNet-EdgeTPU
* EfficientNet-CondConv
* MixNet S, M, L, XL
@ -586,6 +619,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
@ -928,6 +962,24 @@ def efficientnet_b7(pretrained=False, **kwargs):
return model
@register_model
def efficientnet_b8(pretrained=False, **kwargs):
""" EfficientNet-B8 """
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_l2(pretrained=False, **kwargs):
""" EfficientNet-L2."""
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. """
@ -1166,6 +1218,109 @@ def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
return model
@register_model
def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
""" EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
""" EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
""" EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
""" EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
""" EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
""" EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
""" EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
""" EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
""" EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
""" EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_es(pretrained=False, **kwargs):

Loading…
Cancel
Save