Merge branch 'adding_Hardcore_NAS' of https://github.com/yoniaflalo/pytorch-image-models into yoniaflalo-adding_Hardcore_NAS
commit
52b5d0ad0a
@ -0,0 +1,148 @@
|
||||
import torch.nn as nn
|
||||
from .efficientnet_builder import decode_arch_def, resolve_bn_args
|
||||
from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
|
||||
from .layers import hard_sigmoid
|
||||
from .efficientnet_blocks import resolve_act_layer
|
||||
from .registry import register_model
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'hardcorenas_A': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'),
|
||||
'hardcorenas_B': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'),
|
||||
'hardcorenas_C': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'),
|
||||
'hardcorenas_D': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'),
|
||||
'hardcorenas_E': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'),
|
||||
'hardcorenas_F': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'),
|
||||
}
|
||||
|
||||
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
|
||||
"""Creates a hardcorenas model
|
||||
|
||||
Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
|
||||
Paper: https://arxiv.org/abs/2102.11646
|
||||
|
||||
"""
|
||||
num_features = 1280
|
||||
act_layer = resolve_act_layer(kwargs, 'hard_swish')
|
||||
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=32,
|
||||
channel_multiplier=1,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=act_layer,
|
||||
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
features_only = False
|
||||
model_cls = MobileNetV3
|
||||
if model_kwargs.pop('features_only', False):
|
||||
features_only = True
|
||||
model_kwargs.pop('num_classes', 0)
|
||||
model_kwargs.pop('num_features', 0)
|
||||
model_kwargs.pop('head_conv', None)
|
||||
model_kwargs.pop('head_bias', None)
|
||||
model_cls = MobileNetV3Features
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
pretrained_strict=not features_only, **model_kwargs)
|
||||
if features_only:
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_A(pretrained=False, **kwargs):
|
||||
""" hardcorenas_A """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
|
||||
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_A', arch_def=arch_def, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_B(pretrained=False, **kwargs):
|
||||
""" hardcorenas_B """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'],
|
||||
['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
|
||||
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
|
||||
['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
|
||||
['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
|
||||
['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_B', arch_def=arch_def, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_C(pretrained=False, **kwargs):
|
||||
""" hardcorenas_C """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre',
|
||||
'ir_r1_k5_s1_e3_c40_nre'],
|
||||
['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
|
||||
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
|
||||
['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_C', arch_def=arch_def, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_D(pretrained=False, **kwargs):
|
||||
""" hardcorenas_D """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
|
||||
'ir_r1_k3_s1_e3_c80_se0.25'],
|
||||
['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25',
|
||||
'ir_r1_k5_s1_e3_c112_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
|
||||
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_D', arch_def=arch_def, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_E(pretrained=False, **kwargs):
|
||||
""" hardcorenas_E """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
|
||||
'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
|
||||
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
|
||||
'ir_r1_k5_s1_e3_c112_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
|
||||
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_E', arch_def=arch_def, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hardcorenas_F(pretrained=False, **kwargs):
|
||||
""" hardcorenas_F """
|
||||
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
|
||||
'ir_r1_k3_s1_e3_c80_se0.25'],
|
||||
['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
|
||||
'ir_r1_k3_s1_e3_c112_se0.25'],
|
||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25',
|
||||
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
||||
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_F', arch_def=arch_def, **kwargs)
|
||||
return model
|
Loading…
Reference in new issue