"""HardCoReNAS model definitions built on the MobileNetV3 classes.

Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
Paper: https://arxiv.org/abs/2102.11646
"""
from functools import partial

import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
from .registry import register_model


def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict for one model variant.

    Args:
        url: Location of the pretrained weights; empty string when there are none.
        **kwargs: Extra entries merged in last, so they override the defaults below.

    Returns:
        A new dict holding the defaults plus any overrides.
    """
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem',
        'classifier': 'classifier',
        **kwargs
    }


# Per-variant default configs, keyed by the variant name each entry point below
# passes to `_gen_hardcorenas`.
default_cfgs = {
    'hardcorenas_a': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
    'hardcorenas_b': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
    'hardcorenas_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
    'hardcorenas_d': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
    'hardcorenas_e': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
    'hardcorenas_f': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
}


def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Create a HardCoReNAS model from an architecture definition.

    Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
    Paper: https://arxiv.org/abs/2102.11646

    Args:
        pretrained: Forwarded unchanged to `build_model_with_cfg`.
        variant: Variant name (e.g. 'hardcorenas_a') forwarded to `build_model_with_cfg`.
        arch_def: Nested list of block-definition strings, decoded by `decode_arch_def`.
        **kwargs: Additional model arguments. A truthy `features_only` entry selects
            `MobileNetV3Features` instead of `MobileNetV3`.

    Returns:
        The model object returned by `build_model_with_cfg`.
    """
    num_features = 1280
    se_layer = partial(
        SqueezeExcite,
        gate_layer='hard_sigmoid',
        force_act_layer=nn.ReLU,
        rd_round_fn=round_channels,
    )
    # NOTE: argument order matters. `resolve_bn_args` and `resolve_act_layer` receive
    # the live `kwargs` dict before it is expanded by `**kwargs` below; presumably they
    # pop the keys they consume so those keys are not forwarded twice — confirm in
    # efficientnet_builder.
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=32,
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_layer=se_layer,
        **kwargs,
    )

    features_only = bool(model_kwargs.pop('features_only', False))
    if features_only:
        # Classifier-head arguments are filtered out for the feature-extraction class.
        kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias')
        model_cls = MobileNetV3Features
    else:
        kwargs_filter = None
        model_cls = MobileNetV3

    model = build_model_with_cfg(
        model_cls,
        variant,
        pretrained,
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
        **model_kwargs)
    if features_only:
        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
    return model


@register_model
def hardcorenas_a(pretrained=False, **kwargs):
    """Build the hardcorenas_A variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
        ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs)
    return model


@register_model
def hardcorenas_b(pretrained=False, **kwargs):
    """Build the hardcorenas_B variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
        ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
        ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs)
    return model


@register_model
def hardcorenas_c(pretrained=False, **kwargs):
    """Build the hardcorenas_C variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre',
         'ir_r1_k5_s1_e3_c40_nre'],
        ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs)
    return model


@register_model
def hardcorenas_d(pretrained=False, **kwargs):
    """Build the hardcorenas_D variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
        ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
         'ir_r1_k3_s1_e3_c80_se0.25'],
        ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25',
         'ir_r1_k5_s1_e3_c112_se0.25'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
         'ir_r1_k3_s1_e6_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs)
    return model


@register_model
def hardcorenas_e(pretrained=False, **kwargs):
    """Build the hardcorenas_E variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
         'ir_r1_k3_s1_e3_c40_nre_se0.25'],
        ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
         'ir_r1_k5_s1_e3_c112_se0.25'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
         'ir_r1_k3_s1_e6_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs)
    return model


@register_model
def hardcorenas_f(pretrained=False, **kwargs):
    """Build the hardcorenas_F variant.

    Args:
        pretrained: Forwarded to `_gen_hardcorenas`.
        **kwargs: Extra model arguments forwarded to `_gen_hardcorenas`.

    Returns:
        The constructed model.
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
        ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
         'ir_r1_k3_s1_e3_c80_se0.25'],
        ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
         'ir_r1_k3_s1_e3_c112_se0.25'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25',
         'ir_r1_k3_s1_e6_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs)
    return model