diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 8d99d19b..c04aad11 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -29,6 +29,7 @@ from .vision_transformer import *
 from .vovnet import *
 from .xception import *
 from .xception_aligned import *
+from .hardcorenas import *
 
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
@@ -36,3 +37,4 @@ from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import *
+
diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py
new file mode 100644
index 00000000..18422354
--- /dev/null
+++ b/timm/models/hardcorenas.py
@@ -0,0 +1,186 @@
+import torch.nn as nn
+from .efficientnet_builder import decode_arch_def, resolve_bn_args
+from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
+from .layers import hard_sigmoid
+from .efficientnet_blocks import resolve_act_layer
+from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def _cfg(url='', **kwargs):
+    """Build the default pretrained config dict for a model variant.
+
+    Keyword arguments override or extend the default entries.
+    """
+    cfg = dict(
+        url=url,
+        num_classes=1000,
+        input_size=(3, 224, 224),
+        pool_size=(1, 1),
+        crop_pct=0.875,
+        interpolation='bilinear',
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        first_conv='conv_stem',
+        classifier='classifier',
+    )
+    cfg.update(kwargs)
+    return cfg
+
+
+# Default configs (including checkpoint URLs), keyed by model name.
+default_cfgs = {
+    'hardcorenas_A': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'),
+    'hardcorenas_B': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'),
+    'hardcorenas_C': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'),
+    'hardcorenas_D': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'),
+    'hardcorenas_E': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'),
+    'hardcorenas_F': _cfg(
+        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'),
+}
+
+
+def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
+    """Build a HardCoReNAS model from an architecture definition.
+
+    Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
+    Paper: https://arxiv.org/abs/2102.11646
+
+    Args:
+        pretrained: forwarded to ``build_model_with_cfg``.
+        variant: model name; must be a key of ``default_cfgs``.
+        arch_def: nested list of block definition strings, decoded with ``decode_arch_def``.
+        **kwargs: extra model arguments. A truthy ``features_only`` selects
+            ``MobileNetV3Features`` instead of ``MobileNetV3``.
+
+    Returns:
+        The model produced by ``build_model_with_cfg``.
+    """
+    # Evaluated in this order, before ``kwargs`` is expanded into ``model_kwargs``.
+    act_layer = resolve_act_layer(kwargs, 'hard_swish')
+    block_args = decode_arch_def(arch_def)
+    norm_kwargs = resolve_bn_args(kwargs)
+
+    model_kwargs = dict(
+        block_args=block_args,
+        num_features=1280,
+        stem_size=32,
+        channel_multiplier=1,
+        norm_kwargs=norm_kwargs,
+        act_layer=act_layer,
+        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
+        **kwargs,
+    )
+
+    features_only = bool(model_kwargs.pop('features_only', False))
+    if features_only:
+        # Drop the arguments that are not passed on to the features-only model class.
+        for head_arg in ('num_classes', 'num_features', 'head_conv', 'head_bias'):
+            model_kwargs.pop(head_arg, None)
+        model_cls = MobileNetV3Features
+    else:
+        model_cls = MobileNetV3
+
+    model = build_model_with_cfg(
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
+
+
+@register_model
+def hardcorenas_A(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-A model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
+        ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
+        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_A', arch_def=arch_def, **kwargs)
+
+
+@register_model
+def hardcorenas_B(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-B model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
+        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
+        ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
+        ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_B', arch_def=arch_def, **kwargs)
+
+
+@register_model
+def hardcorenas_C(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-C model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
+        ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
+        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_C', arch_def=arch_def, **kwargs)
+
+
+@register_model
+def hardcorenas_D(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-D model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+        ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
+        ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25'],
+        ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25', 'ir_r1_k5_s1_e3_c112_se0.25'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_D', arch_def=arch_def, **kwargs)
+
+
+@register_model
+def hardcorenas_E(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-E model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+        ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
+        ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
+        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e3_c112_se0.25'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_E', arch_def=arch_def, **kwargs)
+
+
+@register_model
+def hardcorenas_F(pretrained=False, **kwargs):
+    """Build the HardCoReNAS-F model."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_nre'],
+        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+        ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
+        ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25'],
+        ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25'],
+        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25'],
+        ['cn_r1_k1_s1_c960'],
+    ]
+    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_F', arch_def=arch_def, **kwargs)