From db056d97e2b640fbec24a0bcc20ff346d3500235 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 May 2019 23:28:13 -0700 Subject: [PATCH] Add MobileNetV3 and associated changes hard-swish, hard-sigmoid, efficient head, etc --- README.md | 1 + models/genmobilenet.py | 338 ++++++++++++++++++++++++++++++---------- models/model_factory.py | 8 +- 3 files changed, 259 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index db84a051..1a4a9eb7 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec * MNASNet B1, A1 (Squeeze-Excite), and Small * MobileNet-V1 * MobileNet-V2 + * MobileNet-V3 (work in progress, validating config) * ChamNet (details hard to find, currently an educated guess) * FBNet-C (TODO A/B variants) ## Features diff --git a/models/genmobilenet.py b/models/genmobilenet.py index 29c2599b..d0665316 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -2,7 +2,7 @@ A generic MobileNet class with building blocks to support a variety of models: * MNasNet B1, A1 (SE), Small -* MobileNetV2 +* MobileNet V1, V2, and V3 (work in progress) * FBNet-C (TODO A & B) * ChamNet (TODO still guessing at architecture definition) * Single-Path NAS Pixel1 @@ -26,10 +26,10 @@ from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.conv2d_same import sconv2d from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40', - 'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', 'mnasnet_small', - 'mobilenetv1_1_00', 'mobilenetv2_1_00', 'chamnetv1_1_00', 'chamnetv2_1_00', - 'fbnetc_1_00', 'spnasnet1_00'] +__all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv1_100', 'mobilenetv2_100', 'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', + 'chamnetv1_100', 'chamnetv2_100', 'fbnetc_100', 'spnasnet_100'] def _cfg(url='', **kwargs): @@ -43,25 +43,28 @@ def _cfg(url='', **kwargs): default_cfgs = { - 'mnasnet0_50': _cfg(url=''), - 'mnasnet0_75': _cfg(url=''), - 'mnasnet1_00': _cfg(url=''), - 'tflite_mnasnet1_00': _cfg(url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet1_00-31639cdc.pth?dl=1', + 'mnasnet_050': _cfg(url=''), + 'mnasnet_075': _cfg(url=''), + 'mnasnet_100': _cfg(url=''), + 'tflite_mnasnet_100': _cfg(url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1', interpolation='bicubic'), - 'mnasnet1_40': _cfg(url=''), - 'semnasnet0_50': _cfg(url=''), - 'semnasnet0_75': _cfg(url=''), - 'semnasnet1_00': _cfg(url=''), - 'tflite_semnasnet1_00': _cfg(url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet1_00-7c780429.pth?dl=1', + 'mnasnet_140': _cfg(url=''), + 'semnasnet_050': _cfg(url=''), + 'semnasnet_075': _cfg(url=''), + 'semnasnet_100': _cfg(url=''), + 'tflite_semnasnet_100': _cfg(url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet_100-7c780429.pth?dl=1', interpolation='bicubic'), - 'semnasnet1_40': _cfg(url=''), + 'semnasnet_140': _cfg(url=''), 'mnasnet_small': _cfg(url=''), - 'mobilenetv1_1_00': _cfg(url=''), - 'mobilenetv2_1_00': _cfg(url=''), - 'chamnetv1_1_00': _cfg(url=''), - 'chamnetv2_1_00': _cfg(url=''), - 'fbnetc_1_00': _cfg(url=''), - 'spnasnet1_00': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1'), + 'mobilenetv1_100': _cfg(url=''), + 'mobilenetv2_100': _cfg(url=''), + 'mobilenetv3_050': _cfg(url=''), + 'mobilenetv3_075': _cfg(url=''), + 'mobilenetv3_100': _cfg(url=''), + 'chamnetv1_100': _cfg(url=''), + 'chamnetv2_100': _cfg(url=''), + 'fbnetc_100': _cfg(url=''), + 'spnasnet_100': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'), } _DEBUG = True @@ -130,6 +133,7 @@ def _decode_block_str(block_str): e - expansion ratio, c - output channels, se - squeeze/excitation ratio + a - activation fn ('re', 'r6', or 'hs') Args: block_str: a string representation of block arguments. Returns: @@ -142,13 +146,33 @@ def _decode_block_str(block_str): block_type = ops[0] # take the block type off the front ops = ops[1:] options = {} + noskip = False for op in ops: - splits = re.split(r'(\d.*)', op) - if len(splits) >= 2: - key, value = splits[:2] + # string options being checked on individual basis, combine if they grow + if op.startswith('a'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = F.relu + elif v == 'r6': + value = F.relu6 + elif v == 'hs': + value = hard_swish + else: + continue options[key] = value + elif op == 'noskip': + noskip = True + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value - # FIXME validate args and throw + # if act_fn is None, the model default (passed to model init) will be used + act_fn = options['a'] if 'a' in options else None num_repeat = int(options['r']) # each type of block has different valid arguments, fill accordingly @@ -157,10 +181,11 @@ def _decode_block_str(block_str): block_type=block_type, kernel_size=int(options['k']), out_chs=int(options['c']), - exp_ratio=int(options['e']), + exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), - noskip=('noskip' in block_str), + act_fn=act_fn, + noskip=noskip, ) if 'g' in options: block_args['pw_group'] = options['g'] @@ -169,9 +194,11 @@ def _decode_block_str(block_str): elif block_type == 'ca': block_args = dict( block_type=block_type, + kernel_size=int(options['k']), out_chs=int(options['c']), stride=int(options['s']), - noskip=('noskip' in block_str), + act_fn=act_fn, + noskip=noskip, ) elif block_type == 'ds' or block_type == 'dsa': block_args = dict( @@ -179,9 +206,18 @@ def _decode_block_str(block_str): kernel_size=int(options['k']), out_chs=int(options['c']), stride=int(options['s']), - noskip=block_type == 'dsa' or 'noskip' in block_str, + act_fn=act_fn, + noskip=block_type == 'dsa' or noskip, pw_act=block_type == 'dsa', ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_fn=act_fn, + ) else: assert False, 'Unknown block type (%s)' % block_type @@ -228,11 +264,15 @@ class _BlockBuilder: """ def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None, + act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, folded_bn=False, padding_same=False): self.depth_multiplier = depth_multiplier self.depth_divisor = depth_divisor self.min_depth = min_depth + self.act_fn = act_fn + self.se_gate_fn = se_gate_fn + self.se_reduce_mid = se_reduce_mid self.bn_momentum = bn_momentum self.bn_eps = bn_eps self.folded_bn = folded_bn @@ -250,15 +290,21 @@ class _BlockBuilder: ba['bn_eps'] = self.bn_eps ba['folded_bn'] = self.folded_bn ba['padding_same'] = self.padding_same + ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn + assert ba['act_fn'] is not None if _DEBUG: print('args:', ba) - # could replace this with lambdas or functools binding if variety increases + # could replace this if with lambdas or functools binding if variety increases if bt == 'ir': + ba['se_gate_fn'] = self.se_gate_fn + ba['se_reduce_mid'] = self.se_reduce_mid block = InvertedResidual(**ba) elif bt == 'ds' or bt == 'dsa': block = DepthwiseSeparableConv(**ba) elif bt == 'ca': - block = CascadeConv3x3(**ba) + block = CascadeConv(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' % bt self.in_chs = ba['out_chs'] # update in_chs for arg of next block @@ -331,6 +377,37 @@ def _initialize_weight_default(m): nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') +def hard_swish(x): + return x * F.relu6(x + 3.) / 6. + + +def hard_sigmoid(x): + return F.relu6(x + 3.) / 6. + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_fn=F.relu, + bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, + folded_bn=False, padding_same=False): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + self.act_fn = act_fn + padding = _padding_arg(_get_padding(kernel_size, stride), padding_same) + + self.conv = sconv2d( + in_chs, out_chs, kernel_size, + stride=stride, padding=padding, bias=folded_bn) + self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) + + def forward(self, x): + x = self.conv(x) + if self.bn1 is not None: + x = self.bn1(x) + x = self.act_fn(x) + return x + + class DepthwiseSeparableConv(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_fn=F.relu, noskip=False, pw_act=False, @@ -370,20 +447,20 @@ class DepthwiseSeparableConv(nn.Module): return x -class CascadeConv3x3(nn.Sequential): - # FIXME lifted from maskrcnn_benchmark blocks, haven't used yet - def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False, +class CascadeConv(nn.Sequential): + # FIXME haven't used yet + def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, act_fn=F.relu, noskip=False, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, folded_bn=False, padding_same=False): - super(CascadeConv3x3, self).__init__() + super(CascadeConv, self).__init__() assert stride in [1, 2] self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.act_fn = act_fn padding = _padding_arg(1, padding_same) - self.conv1 = sconv2d(in_chs, in_chs, 3, stride=stride, padding=padding, bias=folded_bn) + self.conv1 = sconv2d(in_chs, in_chs, kernel_size, stride=stride, padding=padding, bias=folded_bn) self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps) - self.conv2 = sconv2d(in_chs, out_chs, 3, stride=1, padding=padding, bias=folded_bn) + self.conv2 = sconv2d(in_chs, out_chs, kernel_size, stride=1, padding=padding, bias=folded_bn) self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps) def forward(self, x): @@ -401,7 +478,7 @@ class CascadeConv3x3(nn.Sequential): class ChannelShuffle(nn.Module): - # FIXME lifted from maskrcnn_benchmark blocks, haven't used yet + # FIXME haven't used yet def __init__(self, groups): super(ChannelShuffle, self).__init__() self.groups = groups @@ -422,9 +499,10 @@ class ChannelShuffle(nn.Module): class SqueezeExcite(nn.Module): - def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu): + def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid): super(SqueezeExcite, self).__init__() self.act_fn = act_fn + self.gate_fn = gate_fn reduced_chs = reduce_chs or in_chs self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) @@ -435,7 +513,7 @@ class SqueezeExcite(nn.Module): x_se = self.conv_reduce(x_se) x_se = self.act_fn(x_se) x_se = self.conv_expand(x_se) - x = torch.sigmoid(x_se) * x + x = self.gate_fn(x_se) * x return x @@ -444,7 +522,8 @@ class InvertedResidual(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False, - se_ratio=0., shuffle_type=None, pw_group=1, + se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid, + shuffle_type=None, pw_group=1, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, folded_bn=False, padding_same=False): super(InvertedResidual, self).__init__() @@ -470,7 +549,9 @@ class InvertedResidual(nn.Module): # Squeeze-and-excitation if self.has_se: - self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(in_chs * se_ratio))) + reduce_mult = mid_chs if se_reduce_mid else in_chs + self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(reduce_mult * se_ratio)), + act_fn=act_fn, gate_fn=se_gate_fn) # Point-wise linear projection self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn) @@ -519,6 +600,7 @@ class GenMobileNet(nn.Module): An implementation of mobile optimized networks that covers: * MobileNet-V1 * MobileNet-V2 + * MobileNet-V3 * MNASNet A1, B1, and small * FBNet A, B, and C * ChamNet (arch details are murky) @@ -528,7 +610,8 @@ class GenMobileNet(nn.Module): def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, depth_multiplier=1.0, depth_divisor=8, min_depth=None, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, - drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False, + drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False, + global_pool='avg', skip_head_conv=False, efficient_head=False, weight_init='goog', folded_bn=False, padding_same=False): super(GenMobileNet, self).__init__() self.num_classes = num_classes @@ -536,6 +619,7 @@ class GenMobileNet(nn.Module): self.drop_rate = drop_rate self.act_fn = act_fn self.num_features = num_features + self.efficient_head = efficient_head # pool before last conv stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth) self.conv_stem = sconv2d( @@ -545,7 +629,7 @@ class GenMobileNet(nn.Module): in_chs = stem_size builder = _BlockBuilder( - depth_multiplier, depth_divisor, min_depth, + depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid, bn_momentum, bn_eps, folded_bn, padding_same) self.blocks = nn.Sequential(*builder(in_chs, block_args)) in_chs = builder.in_chs @@ -556,8 +640,9 @@ class GenMobileNet(nn.Module): else: self.conv_head = sconv2d( in_chs, self.num_features, 1, - padding=_padding_arg(0, padding_same), bias=folded_bn) - self.bn2 = None if folded_bn else nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps) + padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head) + self.bn2 = None if (folded_bn or efficient_head) else \ + nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.classifier = nn.Linear(self.num_features, self.num_classes) @@ -587,14 +672,23 @@ class GenMobileNet(nn.Module): x = self.bn1(x) x = self.act_fn(x) x = self.blocks(x) - if self.conv_head is not None: + if self.efficient_head: + # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv + x = self.global_pool(x) # always need to pool here regardless of bool x = self.conv_head(x) - if self.bn2 is not None: - x = self.bn2(x) x = self.act_fn(x) - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) + if pool: + # expect flattened output if pool is true, otherwise keep dim + x = x.view(x.size(0), -1) + else: + if self.conv_head is not None: + x = self.conv_head(x) + if self.bn2 is not None: + x = self.bn2(x) + x = self.act_fn(x) + if pool: + x = self.global_pool(x) + x = x.view(x.size(0), -1) return x def forward(self, x): @@ -777,6 +871,52 @@ def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs): return model +def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + depth_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_are_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e6_c24_are'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_are'], # relu + # stage 3, 28x28 in + # FIXME are expansions here correct? + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + # FIXME the paper contains a weird block-stride pattern 1-2-1 that doesn't fit the usual 2-1-... + # What is correct? + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + bn_momentum, bn_eps = _resolve_bn_params(kwargs) + model = GenMobileNet( + arch_def, + num_classes=num_classes, + stem_size=16, + depth_multiplier=depth_multiplier, + depth_divisor=8, + min_depth=None, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + act_fn=hard_swish, + se_gate_fn=hard_sigmoid, + se_reduce_mid=True, + **kwargs + ) + return model + + def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs): """ Generate Chameleon Network (ChamNet) @@ -916,9 +1056,9 @@ def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs): return model -def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): +def mnasnet_050(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ - default_cfg = default_cfgs['mnasnet0_50'] + default_cfg = default_cfgs['mnasnet_050'] model = _gen_mnasnet_b1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -926,9 +1066,9 @@ def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): return model -def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): +def mnasnet_075(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.75. """ - default_cfg = default_cfgs['mnasnet0_75'] + default_cfg = default_cfgs['mnasnet_075'] model = _gen_mnasnet_b1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -936,9 +1076,9 @@ def mnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['mnasnet1_00'] + default_cfg = default_cfgs['mnasnet_100'] model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -946,9 +1086,9 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def tflite_mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['tflite_mnasnet1_00'] + default_cfg = default_cfgs['tflite_mnasnet_100'] # these two args are for compat with tflite pretrained weights kwargs['folded_bn'] = True kwargs['padding_same'] = True @@ -959,9 +1099,9 @@ def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): +def mnasnet_140(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.4 """ - default_cfg = default_cfgs['mnasnet1_40'] + default_cfg = default_cfgs['mnasnet_140'] model = _gen_mnasnet_b1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -969,9 +1109,9 @@ def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): +def semnasnet_050(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ - default_cfg = default_cfgs['semnasnet0_50'] + default_cfg = default_cfgs['semnasnet_050'] model = _gen_mnasnet_a1(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -979,9 +1119,9 @@ def semnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): return model -def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): +def semnasnet_075(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ - default_cfg = default_cfgs['semnasnet0_75'] + default_cfg = default_cfgs['semnasnet_075'] model = _gen_mnasnet_a1(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -989,9 +1129,9 @@ def semnasnet0_75(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ - default_cfg = default_cfgs['semnasnet1_00'] + default_cfg = default_cfgs['semnasnet_100'] model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -999,9 +1139,9 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def tflite_semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet A1, depth multiplier of 1.0. """ - default_cfg = default_cfgs['tflite_semnasnet1_00'] + default_cfg = default_cfgs['tflite_semnasnet_100'] # these two args are for compat with tflite pretrained weights kwargs['folded_bn'] = True kwargs['padding_same'] = True @@ -1012,9 +1152,9 @@ def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs): +def semnasnet_140(num_classes, in_chans=3, pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ - default_cfg = default_cfgs['semnasnet1_40'] + default_cfg = default_cfgs['semnasnet_140'] model = _gen_mnasnet_a1(1.4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1032,9 +1172,9 @@ def mnasnet_small(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def mobilenetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def mobilenetv1_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MobileNet V1 """ - default_cfg = default_cfgs['mobilenetv1_1_00'] + default_cfg = default_cfgs['mobilenetv1_100'] model = _gen_mobilenet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1042,9 +1182,9 @@ def mobilenetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def mobilenetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def mobilenetv2_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ MobileNet V2 """ - default_cfg = default_cfgs['mobilenetv2_1_00'] + default_cfg = default_cfgs['mobilenetv2_100'] model = _gen_mobilenet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1052,9 +1192,39 @@ def mobilenetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def fbnetc_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def mobilenetv3_050(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MobileNet V3 """ + default_cfg = default_cfgs['mobilenetv3_050'] + model = _gen_mobilenet_v3(0.5, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mobilenetv3_075(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MobileNet V3 """ + default_cfg = default_cfgs['mobilenetv3_075'] + model = _gen_mobilenet_v3(0.75, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def mobilenetv3_100(num_classes, in_chans=3, pretrained=False, **kwargs): + """ MobileNet V3 """ + default_cfg = default_cfgs['mobilenetv3_100'] + model = _gen_mobilenet_v3(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def fbnetc_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ FBNet-C """ - default_cfg = default_cfgs['fbnetc_1_00'] + default_cfg = default_cfgs['fbnetc_100'] model = _gen_fbnetc(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1062,9 +1232,9 @@ def fbnetc_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def chamnetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def chamnetv1_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ ChamNet """ - default_cfg = default_cfgs['chamnetv1_1_00'] + default_cfg = default_cfgs['chamnetv1_100'] model = _gen_chamnet_v1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1072,9 +1242,9 @@ def chamnetv1_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def chamnetv2_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ ChamNet """ - default_cfg = default_cfgs['chamnetv2_1_00'] + default_cfg = default_cfgs['chamnetv2_100'] model = _gen_chamnet_v2(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: @@ -1082,9 +1252,9 @@ def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs): return model -def spnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs): +def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs): """ Single-Path NAS Pixel1""" - default_cfg = default_cfgs['spnasnet1_00'] + default_cfg = default_cfgs['spnasnet_100'] model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: diff --git a/models/model_factory.py b/models/model_factory.py index 902804ce..39a6f2d2 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -9,10 +9,10 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn from models.xception import xception from models.pnasnet import pnasnet5large from models.genmobilenet import \ - mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40, tflite_mnasnet1_00,\ - semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, tflite_semnasnet1_00, mnasnet_small,\ - mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\ - spnasnet1_00 + mnasnet_050, mnasnet_075, mnasnet_100, mnasnet_140, tflite_mnasnet_100,\ + semnasnet_050, semnasnet_075, semnasnet_100, semnasnet_140, tflite_semnasnet_100, mnasnet_small,\ + mobilenetv1_100, mobilenetv2_100, mobilenetv3_050, mobilenetv3_075, mobilenetv3_100,\ + fbnetc_100, chamnetv1_100, chamnetv2_100, spnasnet_100 from models.helpers import load_checkpoint