Add ported Tensorflow EfficientNet B4/B5 weights

pull/13/head
Ross Wightman 6 years ago
parent c9a61b7d98
commit 1019414fd2

@ -129,6 +129,19 @@ I've leveraged the training scripts in this repository to train a few of the mod
| tf_inception_v3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |
#### @ 380x380
| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
|---|---|---|---|---|---|
| tf_efficientnet_b4 | 82.604 (17.396) | 96.128 (3.872) | 19.34 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_efficientnet_b4 *tfp | 82.604 (17.396) | 96.094 (3.906) | 19.34 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
#### @ 456x456
| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
|---|---|---|---|---|---|
| tf_efficientnet_b5 *tfp | 83.200 (16.800) | 96.456 (3.544) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_efficientnet_b5 | 83.176 (16.824) | 96.536 (3.464) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far to leveled off in the same 72-73% range.
Models with `*tfp` next to them were scored with `--tf-preprocessing` flag.

@ -31,9 +31,9 @@ _models = [
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'tf_efficientnet_b0',
'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0', 'efficientnet_b1',
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'tf_efficientnet_b0',
'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5']
__all__ = ['GenEfficientNet', 'gen_efficientnet_model_names'] + _models
@ -91,6 +91,8 @@ default_cfgs = {
url='', input_size=(3, 300, 300), pool_size=(10, 10)),
'efficientnet_b4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12)),
'efficientnet_b5': _cfg(
url='', input_size=(3, 456, 456), pool_size=(15, 15)),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
input_size=(3, 224, 224), interpolation='bicubic'),
@ -103,8 +105,15 @@ default_cfgs = {
'tf_efficientnet_b3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
input_size=(3, 300, 300), pool_size=(10, 10), interpolation='bicubic', crop_pct=0.904),
'tf_efficientnet_b4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
input_size=(3, 380, 380), pool_size=(12, 12), interpolation='bicubic', crop_pct=0.922),
'tf_efficientnet_b5': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
input_size=(3, 456, 456), pool_size=(15, 15), interpolation='bicubic', crop_pct=0.934)
}
_DEBUG = False
# Default args for PyTorch BN impl
@ -1436,6 +1445,19 @@ def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def efficientnet_b5(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet-B5 """
# NOTE for train, drop_rate should be 0.4
default_cfg = default_cfgs['efficientnet_b5']
model = _gen_efficientnet(
channel_multiplier=1.6, depth_multiplier=2.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def tf_efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b0']
@ -1492,5 +1514,33 @@ def tf_efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def tf_efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet-B4. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b4']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
model = _gen_efficientnet(
channel_multiplier=1.4, depth_multiplier=1.8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def tf_efficientnet_b5(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet-B5. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_b5']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['padding_same'] = True
model = _gen_efficientnet(
channel_multiplier=1.6, depth_multiplier=2.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def gen_efficientnet_model_names():
return set(_models)

@ -53,6 +53,10 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
if 'url' not in default_cfg or not default_cfg['url']:
print("Warning: pretrained model URL is invalid, using random initialization.")
return
state_dict = model_zoo.load_url(default_cfg['url'])
if in_chans == 1:

Loading…
Cancel
Save