added efficientnet pruned weights

pull/136/head
AFLALO, Jonathan Isaac 5 years ago
parent a4d20a1cb8
commit 9c15d57505

@ -27,7 +27,7 @@ Hacked together by Ross Wightman
from .efficientnet_builder import * from .efficientnet_builder import *
from .feature_hooks import FeatureHooks from .feature_hooks import FeatureHooks
from .registry import register_model from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -131,6 +131,16 @@ default_cfgs = {
'efficientnet_lite4': _cfg( 'efficientnet_lite4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'efficientnet_b1_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b2_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b3_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_efficientnet_b0': _cfg( 'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)), input_size=(3, 224, 224)),
@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
else: else:
load_strict = True load_strict = True
model_class = EfficientNet model_class = EfficientNet
variant = model_kwargs.pop('variant', '')
model = model_class(**model_kwargs) model = model_class(**model_kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if '_pruned' in variant:
model = adapt_model_from_file(model, variant)
if pretrained: if pretrained:
load_pretrained( load_pretrained(
model, model,
@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
act_layer=Swish, act_layer=Swish,
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
variant=variant,
**kwargs, **kwargs,
) )
model = _create_model(model_kwargs, default_cfgs[variant], pretrained) model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
return model return model
@register_model
def efficientnet_b1_pruned(pretrained=False, **kwargs):
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
variant = 'efficientnet_b1_pruned'
model = _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b2_pruned(pretrained=False, **kwargs):
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b3_pruned(pretrained=False, **kwargs):
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model @register_model
def tf_efficientnet_b0(pretrained=False, **kwargs): def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """ """ EfficientNet-B0. Tensorflow compatible variant """

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save