diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
index 2541bc6b..6e54c70e 100644
--- a/timm/models/gen_efficientnet.py
+++ b/timm/models/gen_efficientnet.py
@@ -50,18 +50,12 @@ default_cfgs = {
     'mnasnet_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
         interpolation='bicubic'),
-    'tflite_mnasnet_100': _cfg(
-        url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
-        interpolation='bicubic'),
     'mnasnet_140': _cfg(url=''),
     'semnasnet_050': _cfg(url=''),
     'semnasnet_075': _cfg(url=''),
     'semnasnet_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
         interpolation='bicubic'),
-    'tflite_semnasnet_100': _cfg(
-        url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet_100-7c780429.pth?dl=1',
-        interpolation='bicubic'),
     'semnasnet_140': _cfg(url=''),
     'mnasnet_small': _cfg(url=''),
     'mobilenetv1_100': _cfg(url=''),
@@ -118,6 +112,7 @@ _DEBUG = False
 # Default args for PyTorch BN impl
 _BN_MOMENTUM_PT_DEFAULT = 0.1
 _BN_EPS_PT_DEFAULT = 1e-5
+_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
 
 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
 # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
@@ -126,23 +121,18 @@ _BN_EPS_PT_DEFAULT = 1e-5
 # .9997 (/w .999 in search space) for paper
 _BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
 _BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
 
 
-def _resolve_bn_params(kwargs):
-    # NOTE kwargs passed as dict intentionally
-    bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT
-    bn_eps_default = _BN_EPS_PT_DEFAULT
-    bn_tf = kwargs.pop('bn_tf', False)
-    if bn_tf:
-        bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT
-        bn_eps_default = _BN_EPS_TF_DEFAULT
+def _resolve_bn_args(kwargs):
+    bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
     bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
     bn_eps = kwargs.pop('bn_eps', None)
-    if bn_momentum is None:
-        bn_momentum = bn_momentum_default
-    if bn_eps is None:
-        bn_eps = bn_eps_default
-    return bn_momentum, bn_eps
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
 
 
 def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
@@ -292,6 +282,31 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0):
     return arch_args
 
 
+def swish(x, inplace=False):
+    if inplace:
+        return x.mul_(x.sigmoid())
+    else:
+        return x * x.sigmoid()
+
+
+def sigmoid(x, inplace=False):
+    return x.sigmoid_() if inplace else x.sigmoid()
+
+
+def hard_swish(x, inplace=False):
+    if inplace:
+        return x.mul_(F.relu6(x + 3.) / 6.)
+    else:
+        return x * F.relu6(x + 3.) / 6.
+
+
+def hard_sigmoid(x, inplace=False):
+    if inplace:
+        return x.add_(3.).clamp_(0., 6.).div_(6.)
+    else:
+        return F.relu6(x + 3.) / 6.
+
+
 class _BlockBuilder:
     """ Build Trunk Blocks
 
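Note: the consolidated `_resolve_bn_args` above pops its keys out of `kwargs`, so BN options never leak into downstream constructors. A minimal standalone check (the function body is copied from the hunk above; the sample kwargs are illustrative only):

```python
_BN_ARGS_PT = dict(momentum=0.1, eps=1e-5)
_BN_ARGS_TF = dict(momentum=1 - 0.99, eps=1e-3)

def _resolve_bn_args(kwargs):
    # copy so per-model overrides never mutate the shared default dicts
    bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args

kwargs = dict(bn_tf=True, bn_eps=1e-4, drop_rate=0.2)
print(_resolve_bn_args(kwargs))  # {'momentum': ~0.01, 'eps': 0.0001}
print(kwargs)                    # {'drop_rate': 0.2} -- the bn_* keys were popped
```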
@@ -303,9 +318,9 @@ class _BlockBuilder:
 
     """
     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, verbose=False):
+                 drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
+                 bn_args=_BN_ARGS_PT, padding_same=False,
+                 verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
         self.channel_min = channel_min
@@ -313,9 +328,7 @@ class _BlockBuilder:
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
-        self.bn_momentum = bn_momentum
-        self.bn_eps = bn_eps
-        self.folded_bn = folded_bn
+        self.bn_args = bn_args
         self.padding_same = padding_same
         self.verbose = verbose
 
@@ -331,9 +344,7 @@ class _BlockBuilder:
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
         ba['out_chs'] = self._round_channels(ba['out_chs'])
-        ba['bn_momentum'] = self.bn_momentum
-        ba['bn_eps'] = self.bn_eps
-        ba['folded_bn'] = self.folded_bn
+        ba['bn_args'] = self.bn_args
         ba['padding_same'] = self.padding_same
         # block act fn overrides the model default
         ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
@@ -427,18 +438,6 @@ def _initialize_weight_default(m):
         nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
 
 
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-def hard_swish(x):
-    return x * F.relu6(x + 3.) / 6.
-
-
-def hard_sigmoid(x):
-    return F.relu6(x + 3.) / 6.
-
-
 def drop_connect(inputs, training=False, drop_connect_rate=0.):
     """Apply drop connect."""
     if not training:
@@ -474,7 +473,7 @@ class ChannelShuffle(nn.Module):
 
 
 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
+    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid):
         super(SqueezeExcite, self).__init__()
         self.act_fn = act_fn
         self.gate_fn = gate_fn
@@ -486,17 +485,16 @@ class SqueezeExcite(nn.Module):
         # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
         x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
         x_se = self.conv_reduce(x_se)
-        x_se = self.act_fn(x_se)
+        x_se = self.act_fn(x_se, inplace=True)
         x_se = self.conv_expand(x_se)
-        x = self.gate_fn(x_se) * x
+        x = x * self.gate_fn(x_se)
         return x
 
 
 class ConvBnAct(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 bn_args=_BN_ARGS_PT, padding_same=False):
         super(ConvBnAct, self).__init__()
         assert stride in [1, 2]
         self.act_fn = act_fn
@@ -504,14 +502,13 @@ class ConvBnAct(nn.Module):
         self.conv = sconv2d(
             in_chs, out_chs, kernel_size,
-            stride=stride, padding=padding, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+            stride=stride, padding=padding, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
         x = self.conv(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
         return x
 
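Note: `ConvBnAct` and `SqueezeExcite` now assume BN is always present (`bias=False`, unconditional `bn1`) and call activations with `inplace=True`. A self-contained sketch of the SE gating as reordered above (`x * gate(x_se)`), substituting plain `nn.Conv2d` for this file's `sconv2d`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SqueezeExcite(nn.Module):
    def __init__(self, in_chs, reduce_chs, gate_fn=torch.sigmoid):
        super(SqueezeExcite, self).__init__()
        self.conv_reduce = nn.Conv2d(in_chs, reduce_chs, 1, bias=True)
        self.conv_expand = nn.Conv2d(reduce_chs, in_chs, 1, bias=True)
        self.gate_fn = gate_fn

    def forward(self, x):
        # mean over spatial dims instead of adaptive avg pool (see the AMP note above)
        x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
        x_se = self.conv_expand(F.relu(self.conv_reduce(x_se), inplace=True))
        return x * self.gate_fn(x_se)  # scale the input by the gate

print(SqueezeExcite(32, 8)(torch.randn(2, 32, 7, 7)).shape)  # torch.Size([2, 32, 7, 7])
```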
@@ -522,9 +519,8 @@ class DepthwiseSeparableConv(nn.Module):
     """
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
-                 se_ratio=0., se_gate_fn=torch.sigmoid,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
+                 se_ratio=0., se_gate_fn=sigmoid,
+                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_se = se_ratio is not None and se_ratio > 0.
@@ -537,33 +533,31 @@ class DepthwiseSeparableConv(nn.Module):
 
         self.conv_dw = sconv2d(
             in_chs, in_chs, kernel_size,
-            stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+            stride=stride, padding=dw_padding, groups=in_chs, bias=False)
+        self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)
 
         # Squeeze-and-excitation
         if self.has_se:
             self.se = SqueezeExcite(
                 in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
 
-        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
         residual = x
 
         x = self.conv_dw(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
         if self.has_se:
             x = self.se(x)
 
         x = self.conv_pw(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
+        x = self.bn2(x)
         if self.has_pw_act:
-            x = self.act_fn(x)
+            x = self.act_fn(x, inplace=True)
 
         if self.has_residual:
             if self.drop_connect_rate > 0.:
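Note: the residual path above routes through `drop_connect`, whose body sits outside these hunks. For context, the standard form (per the TF EfficientNet reference; only the signature and the `not training` early-return are visible in this diff) is:

```python
import torch

def drop_connect(inputs, training=False, drop_connect_rate=0.):
    if not training:
        return inputs
    keep_prob = 1. - drop_connect_rate
    # one Bernoulli draw per sample, broadcast across C/H/W
    random_tensor = keep_prob + torch.rand(
        (inputs.size(0), 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
    random_tensor.floor_()  # binarize to 0/1
    return inputs.div(keep_prob) * random_tensor  # rescale so expectation is preserved
```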
@@ -577,10 +571,9 @@
 class InvertedResidual(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
-                 se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
+                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                  shuffle_type=None, pw_group=1,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
+                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
@@ -591,8 +584,8 @@ class InvertedResidual(nn.Module):
         pw_padding = _padding_arg(0, padding_same)
 
         # Point-wise expansion
-        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
 
         self.shuffle_type = shuffle_type
         if shuffle_type is not None:
@@ -600,8 +593,8 @@ class InvertedResidual(nn.Module):
 
         # Depth-wise convolution
         self.conv_dw = sconv2d(
-            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
+        self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)
 
         # Squeeze-and-excitation
         if self.has_se:
@@ -610,17 +603,16 @@ class InvertedResidual(nn.Module):
                 mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
 
         # Point-wise linear projection
-        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
-        self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
 
     def forward(self, x):
         residual = x
 
         # Point-wise expansion
         x = self.conv_pw(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
 
         # FIXME haven't tried this yet
         # for channel shuffle when using groups with pointwise convs as per FBNet variants
@@ -629,9 +621,8 @@ class InvertedResidual(nn.Module):
 
         # Depth-wise convolution
         x = self.conv_dw(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
-        x = self.act_fn(x)
+        x = self.bn2(x)
+        x = self.act_fn(x, inplace=True)
 
         # Squeeze-and-excitation
         if self.has_se:
@@ -639,8 +630,7 @@ class InvertedResidual(nn.Module):
 
         # Point-wise linear projection
         x = self.conv_pwl(x)
-        if self.bn3 is not None:
-            x = self.bn3(x)
+        x = self.bn3(x)
 
         if self.has_residual:
             if self.drop_connect_rate > 0.:
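Note: every `nn.BatchNorm2d` construction site in these blocks now unpacks the same `bn_args` dict, so a BN configuration change is made in one place instead of at each `momentum=`/`eps=` pair. Illustrative values only:

```python
import torch.nn as nn

bn_args = dict(momentum=1 - 0.99, eps=1e-3)  # e.g. the TF-style defaults from above
bn = nn.BatchNorm2d(64, **bn_args)
print(bn.momentum, bn.eps)  # ~0.01 0.001
```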
@@ -668,11 +658,9 @@ class GenEfficientNet(nn.Module):
 
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
-                 se_gate_fn=torch.sigmoid, se_reduce_mid=False,
-                 global_pool='avg', head_conv='default', weight_init='goog',
-                 folded_bn=False, padding_same=False,):
+                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
+                 global_pool='avg', head_conv='default', weight_init='goog', padding_same=False):
         super(GenEfficientNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -682,14 +670,14 @@ class GenEfficientNet(nn.Module):
         stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
         self.conv_stem = sconv2d(
             in_chans, stem_size, 3,
-            padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
+            padding=_padding_arg(1, padding_same), stride=2, bias=False)
+        self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
         in_chs = stem_size
 
         builder = _BlockBuilder(
             channel_multiplier, channel_divisor, channel_min,
             drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
-            bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
+            bn_args, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
 
@@ -701,9 +689,8 @@ class GenEfficientNet(nn.Module):
             self.efficient_head = head_conv == 'efficient'
             self.conv_head = sconv2d(
                 in_chs, self.num_features, 1,
-                padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
-            self.bn2 = None if (folded_bn or self.efficient_head) else \
-                nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
+                padding=_padding_arg(0, padding_same), bias=False)
+            self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
@@ -729,25 +716,23 @@ class GenEfficientNet(nn.Module):
 
     def forward_features(self, x, pool=True):
         x = self.conv_stem(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
         x = self.blocks(x)
         if self.efficient_head:
             # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
             x = self.global_pool(x)  # always need to pool here regardless of flag
             x = self.conv_head(x)
             # no BN
-            x = self.act_fn(x)
+            x = self.act_fn(x, inplace=True)
             if pool:
                 # expect flattened output if pool is true, otherwise keep dim
                 x = x.view(x.size(0), -1)
         else:
             if self.conv_head is not None:
                 x = self.conv_head(x)
-            if self.bn2 is not None:
-                x = self.bn2(x)
-            x = self.act_fn(x)
+            x = self.bn2(x)
+            x = self.act_fn(x, inplace=True)
             if pool:
                 x = self.global_pool(x)
                 x = x.view(x.size(0), -1)
@@ -785,7 +770,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -793,8 +777,7 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -825,7 +808,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -833,8 +815,7 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -858,7 +839,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r3_k5_s2_e6_c88_se0.25'],
         ['ir_r1_k3_s1_e6_c144']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -866,8 +846,7 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
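Note: each `_gen_*` factory now forwards `kwargs` to `_resolve_bn_args` at the call site, so BN options ride along with any other model kwargs. Hypothetical usage, assuming `timm.create_model` is available in this version (the option names are the real ones popped by `_resolve_bn_args` above):

```python
import timm

# TF-style BN defaults (momentum ~0.01, eps 1e-3), e.g. for weights ported from TF
model = timm.create_model('mnasnet_100', bn_tf=True)

# or override a single BN arg, keeping the PyTorch defaults otherwise
model = timm.create_model('mnasnet_100', bn_eps=1e-3)
```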
@@ -885,7 +864,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['dsa_r6_k3_s2_c512'],
         ['dsa_r2_k3_s2_c1024'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -894,8 +872,7 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=F.relu6,
         head_conv='none',
         **kwargs
@@ -917,7 +894,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r3_k3_s2_e6_c160'],
         ['ir_r1_k3_s1_e6_c320'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -925,8 +901,7 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=F.relu6,
         **kwargs
     )
@@ -958,7 +933,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['cn_r1_k1_s1_c960'],  # hard-swish
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -966,8 +940,7 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=hard_swish,
         se_gate_fn=hard_sigmoid,
         se_reduce_mid=True,
@@ -994,7 +967,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r4_k3_s2_e7_c152'],
         ['ir_r1_k3_s1_e10_c104'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1003,8 +975,7 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1027,7 +998,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r6_k3_s2_e2_c152'],
         ['ir_r1_k3_s1_e6_c112'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1036,8 +1006,7 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1061,7 +1030,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r4_k5_s2_e6_c184'],
         ['ir_r1_k3_s1_e6_c352'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1070,8 +1038,7 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
    )
     return model
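Note: `_gen_mobilenet_v3` above wires `act_fn=hard_swish` and `se_gate_fn=hard_sigmoid`, the piecewise-linear approximations defined earlier in this patch. A quick numeric check of how closely the hard gate tracks the smooth one:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6., 6., steps=121)
hard = F.relu6(x + 3.) / 6.       # hard_sigmoid from this patch
soft = torch.sigmoid(x)
print((hard - soft).abs().max())  # worst-case gap < 0.1, around |x| ~ 1-3
```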
@@ -1101,7 +1068,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1109,8 +1075,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1147,7 +1112,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
         ['ir_r4_k5_s2_e6_c192_se0.25'],
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     # NOTE: other models in the family didn't scale the feature count
     num_features = _round_channels(1280, channel_multiplier, 8, None)
     model = GenEfficientNet(
@@ -1158,8 +1122,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
         channel_divisor=8,
         channel_min=None,
         num_features=num_features,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=swish,
         **kwargs
     )
@@ -1205,20 +1168,6 @@ def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return mnasnet_100(pretrained, num_classes, in_chans, **kwargs)
 
 
-@register_model
-def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """ MNASNet B1, depth multiplier of 1.0. """
-    default_cfg = default_cfgs['tflite_mnasnet_100']
-    # these two args are for compat with tflite pretrained weights
-    kwargs['folded_bn'] = True
-    kwargs['padding_same'] = True
-    model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ MNASNet B1, depth multiplier of 1.4 """
@@ -1269,20 +1218,6 @@ def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return semnasnet_100(pretrained, num_classes, in_chans, **kwargs)
 
 
-@register_model
-def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """ MNASNet A1, depth multiplier of 1.0. """
-    default_cfg = default_cfgs['tflite_semnasnet_100']
-    # these two args are for compat with tflite pretrained weights
-    kwargs['folded_bn'] = True
-    kwargs['padding_same'] = True
-    model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """
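Note: with the `folded_bn`/`padding_same` compat path gone, the `tflite_mnasnet_100` and `tflite_semnasnet_100` entrypoints are removed outright; `mnasnet_100`/`semnasnet_100` (with the release weights in `default_cfgs`) are the replacements. A hypothetical migration check, assuming `timm.list_models` from the same model registry that backs `@register_model` here:

```python
import timm

assert 'tflite_mnasnet_100' not in timm.list_models()
model = timm.create_model('mnasnet_100', pretrained=True)  # release weights, bicubic interp
```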