Merge pull request #136 from yoniaflalo/adding_effnet_pruned

added efficientnet pruned weights
pull/140/head
Ross Wightman 4 years ago committed by GitHub
commit 8ec554b82e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,7 +27,7 @@ Hacked together by Ross Wightman
from .efficientnet_builder import *
from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -131,6 +131,16 @@ default_cfgs = {
'efficientnet_lite4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'efficientnet_b1_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b2_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b3_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)),
@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
else:
load_strict = True
model_class = EfficientNet
variant = model_kwargs.pop('variant', '')
model = model_class(**model_kwargs)
model.default_cfg = default_cfg
if '_pruned' in variant:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
channel_multiplier=channel_multiplier,
act_layer=Swish,
norm_kwargs=resolve_bn_args(kwargs),
variant=variant,
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
return model
@register_model
def efficientnet_b1_pruned(pretrained=False, **kwargs):
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
variant = 'efficientnet_b1_pruned'
model = _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b2_pruned(pretrained=False, **kwargs):
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b3_pruned(pretrained=False, **kwargs):
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save