From c1cf9712fc566182952fd7053a1b55399173ff1a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 19 Apr 2021 10:42:56 -0700 Subject: [PATCH] Add updated EfficientNet-V2S weights, 83.8 @ 384x384 test. Add PyTorch trained EfficientNet-B4 weights, 83.4 @ 384x384 test. Tweak non TF EfficientNet B1-B4 train/test res scaling. --- timm/models/efficientnet.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 8f285759..0c414d50 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -85,21 +85,16 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), 'efficientnet_b1': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', - input_size=(3, 240, 240), pool_size=(8, 8)), + test_input_size=(3, 256, 256), crop_pct=1.0), 'efficientnet_b2': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', - input_size=(3, 260, 260), pool_size=(9, 9)), - 'efficientnet_b2a': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', - input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0), + input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), 'efficientnet_b3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'efficientnet_b3a': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', - input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0), + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), 'efficientnet_b4': _cfg( - url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', + input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), 'efficientnet_b5': _cfg( url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), 'efficientnet_b6': _cfg( @@ -155,8 +150,8 @@ default_cfgs = { input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), 'efficientnet_v2s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2-b265c1ba.pth', - input_size=(3, 224, 224), test_input_size=(3, 320, 320), pool_size=(7, 7), crop_pct=1.0), # FIXME WIP + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), # FIXME WIP 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', @@ -1077,10 +1072,8 @@ def efficientnet_b2(pretrained=False, **kwargs): @register_model def efficientnet_b2a(pretrained=False, **kwargs): """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop""" - # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 - model = _gen_efficientnet( - 'efficientnet_b2a', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) - return model + # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now + return efficientnet_b2(pretrained=pretrained, **kwargs) @register_model @@ -1095,10 +1088,8 @@ def efficientnet_b3(pretrained=False, **kwargs): @register_model def efficientnet_b3a(pretrained=False, **kwargs): """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """ - # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 - model = _gen_efficientnet( - 'efficientnet_b3a', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) - return model + # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now + return efficientnet_b3(pretrained=pretrained, **kwargs) @register_model