EfficientNet and related cleanup

* remove folded_bn support and corresponding untrainable tflite ported weights
* combine bn args into dict
* add inplace support to activations and use where possible for reduced mem on large models
pull/23/head
Ross Wightman 6 years ago
parent c11973602d
commit d6ac5bbc48

@ -50,18 +50,12 @@ default_cfgs = {
'mnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
interpolation='bicubic'),
'tflite_mnasnet_100': _cfg(
url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
interpolation='bicubic'),
'mnasnet_140': _cfg(url=''),
'semnasnet_050': _cfg(url=''),
'semnasnet_075': _cfg(url=''),
'semnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
interpolation='bicubic'),
'tflite_semnasnet_100': _cfg(
url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet_100-7c780429.pth?dl=1',
interpolation='bicubic'),
'semnasnet_140': _cfg(url=''),
'mnasnet_small': _cfg(url=''),
'mobilenetv1_100': _cfg(url=''),
@ -118,6 +112,7 @@ _DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
@ -126,23 +121,18 @@ _BN_EPS_PT_DEFAULT = 1e-5
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _resolve_bn_params(kwargs):
# NOTE kwargs passed as dict intentionally
bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT
bn_eps_default = _BN_EPS_PT_DEFAULT
bn_tf = kwargs.pop('bn_tf', False)
if bn_tf:
bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT
bn_eps_default = _BN_EPS_TF_DEFAULT
def _resolve_bn_args(kwargs):
bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_momentum is None:
bn_momentum = bn_momentum_default
if bn_eps is None:
bn_eps = bn_eps_default
return bn_momentum, bn_eps
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
@ -292,6 +282,31 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0):
return arch_args
def swish(x, inplace=False):
if inplace:
return x.mul_(x.sigmoid())
else:
return x * x.sigmoid()
def sigmoid(x, inplace=False):
return x.sigmoid_() if inplace else x.sigmoid()
def hard_swish(x, inplace=False):
if inplace:
return x.mul_(F.relu6(x + 3.) / 6.)
else:
return x * F.relu6(x + 3.) / 6.
def hard_sigmoid(x, inplace=False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
return F.relu6(x + 3.) / 6.
class _BlockBuilder:
""" Build Trunk Blocks
@ -303,9 +318,9 @@ class _BlockBuilder:
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False, verbose=False):
drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
bn_args=_BN_ARGS_PT, padding_same=False,
verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
@ -313,9 +328,7 @@ class _BlockBuilder:
self.act_fn = act_fn
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.folded_bn = folded_bn
self.bn_args = bn_args
self.padding_same = padding_same
self.verbose = verbose
@ -331,9 +344,7 @@ class _BlockBuilder:
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
ba['bn_momentum'] = self.bn_momentum
ba['bn_eps'] = self.bn_eps
ba['folded_bn'] = self.folded_bn
ba['bn_args'] = self.bn_args
ba['padding_same'] = self.padding_same
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
@ -427,18 +438,6 @@ def _initialize_weight_default(m):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
def swish(x):
return x * torch.sigmoid(x)
def hard_swish(x):
return x * F.relu6(x + 3.) / 6.
def hard_sigmoid(x):
return F.relu6(x + 3.) / 6.
def drop_connect(inputs, training=False, drop_connect_rate=0.):
"""Apply drop connect."""
if not training:
@ -474,7 +473,7 @@ class ChannelShuffle(nn.Module):
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid):
super(SqueezeExcite, self).__init__()
self.act_fn = act_fn
self.gate_fn = gate_fn
@ -486,17 +485,16 @@ class SqueezeExcite(nn.Module):
# NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
x_se = self.conv_reduce(x_se)
x_se = self.act_fn(x_se)
x_se = self.act_fn(x_se, inplace=True)
x_se = self.conv_expand(x_se)
x = self.gate_fn(x_se) * x
x = x * self.gate_fn(x_se)
return x
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
bn_args=_BN_ARGS_PT, padding_same=False):
super(ConvBnAct, self).__init__()
assert stride in [1, 2]
self.act_fn = act_fn
@ -504,14 +502,13 @@ class ConvBnAct(nn.Module):
self.conv = sconv2d(
in_chs, out_chs, kernel_size,
stride=stride, padding=padding, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
stride=stride, padding=padding, bias=False)
self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
x = self.conv(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.bn1(x)
x = self.act_fn(x, inplace=True)
return x
@ -522,9 +519,8 @@ class DepthwiseSeparableConv(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
se_ratio=0., se_gate_fn=torch.sigmoid,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False, drop_connect_rate=0.):
se_ratio=0., se_gate_fn=sigmoid,
bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
self.has_se = se_ratio is not None and se_ratio > 0.
@ -537,33 +533,31 @@ class DepthwiseSeparableConv(nn.Module):
self.conv_dw = sconv2d(
in_chs, in_chs, kernel_size,
stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
stride=stride, padding=dw_padding, groups=in_chs, bias=False)
self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)
# Squeeze-and-excitation
if self.has_se:
self.se = SqueezeExcite(
in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
residual = x
x = self.conv_dw(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.bn1(x)
x = self.act_fn(x, inplace=True)
if self.has_se:
x = self.se(x)
x = self.conv_pw(x)
if self.bn2 is not None:
x = self.bn2(x)
x = self.bn2(x)
if self.has_pw_act:
x = self.act_fn(x)
x = self.act_fn(x, inplace=True)
if self.has_residual:
if self.drop_connect_rate > 0.:
@ -577,10 +571,9 @@ class InvertedResidual(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
shuffle_type=None, pw_group=1,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False, drop_connect_rate=0.):
bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
@ -591,8 +584,8 @@ class InvertedResidual(nn.Module):
pw_padding = _padding_arg(0, padding_same)
# Point-wise expansion
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
self.shuffle_type = shuffle_type
if shuffle_type is not None:
@ -600,8 +593,8 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
self.conv_dw = sconv2d(
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)
# Squeeze-and-excitation
if self.has_se:
@ -610,17 +603,16 @@ class InvertedResidual(nn.Module):
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
# Point-wise linear projection
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.bn1(x)
x = self.act_fn(x, inplace=True)
# FIXME haven't tried this yet
# for channel shuffle when using groups with pointwise convs as per FBNet variants
@ -629,9 +621,8 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
x = self.conv_dw(x)
if self.bn2 is not None:
x = self.bn2(x)
x = self.act_fn(x)
x = self.bn2(x)
x = self.act_fn(x, inplace=True)
# Squeeze-and-excitation
if self.has_se:
@ -639,8 +630,7 @@ class InvertedResidual(nn.Module):
# Point-wise linear projection
x = self.conv_pwl(x)
if self.bn3 is not None:
x = self.bn3(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
@ -668,11 +658,9 @@ class GenEfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
se_gate_fn=torch.sigmoid, se_reduce_mid=False,
global_pool='avg', head_conv='default', weight_init='goog',
folded_bn=False, padding_same=False,):
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
global_pool='avg', head_conv='default', weight_init='goog', padding_same=False):
super(GenEfficientNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -682,14 +670,14 @@ class GenEfficientNet(nn.Module):
stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = sconv2d(
in_chans, stem_size, 3,
padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
padding=_padding_arg(1, padding_same), stride=2, bias=False)
self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
in_chs = stem_size
builder = _BlockBuilder(
channel_multiplier, channel_divisor, channel_min,
drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
bn_args, padding_same, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
@ -701,9 +689,8 @@ class GenEfficientNet(nn.Module):
self.efficient_head = head_conv == 'efficient'
self.conv_head = sconv2d(
in_chs, self.num_features, 1,
padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
self.bn2 = None if (folded_bn or self.efficient_head) else \
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
padding=_padding_arg(0, padding_same), bias=False)
self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
@ -729,25 +716,23 @@ class GenEfficientNet(nn.Module):
def forward_features(self, x, pool=True):
x = self.conv_stem(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.bn1(x)
x = self.act_fn(x, inplace=True)
x = self.blocks(x)
if self.efficient_head:
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
x = self.global_pool(x) # always need to pool here regardless of flag
x = self.conv_head(x)
# no BN
x = self.act_fn(x)
x = self.act_fn(x, inplace=True)
if pool:
# expect flattened output if pool is true, otherwise keep dim
x = x.view(x.size(0), -1)
else:
if self.conv_head is not None:
x = self.conv_head(x)
if self.bn2 is not None:
x = self.bn2(x)
x = self.act_fn(x)
x = self.bn2(x)
x = self.act_fn(x, inplace=True)
if pool:
x = self.global_pool(x)
x = x.view(x.size(0), -1)
@ -785,7 +770,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -793,8 +777,7 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -825,7 +808,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320_noskip']
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -833,8 +815,7 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -858,7 +839,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
['ir_r3_k5_s2_e6_c88_se0.25'],
['ir_r1_k3_s1_e6_c144']
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -866,8 +846,7 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -885,7 +864,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
['dsa_r6_k3_s2_c512'],
['dsa_r2_k3_s2_c1024'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -894,8 +872,7 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu6,
head_conv='none',
**kwargs
@ -917,7 +894,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
['ir_r3_k3_s2_e6_c160'],
['ir_r1_k3_s1_e6_c320'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -925,8 +901,7 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu6,
**kwargs
)
@ -958,7 +933,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -966,8 +940,7 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
act_fn=hard_swish,
se_gate_fn=hard_sigmoid,
se_reduce_mid=True,
@ -994,7 +967,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
['ir_r4_k3_s2_e7_c152'],
['ir_r1_k3_s1_e10_c104'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -1003,8 +975,7 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -1027,7 +998,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
['ir_r6_k3_s2_e2_c152'],
['ir_r1_k3_s1_e6_c112'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -1036,8 +1006,7 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -1061,7 +1030,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
['ir_r4_k5_s2_e6_c184'],
['ir_r1_k3_s1_e6_c352'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -1070,8 +1038,7 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -1101,7 +1068,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320_noskip']
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
@ -1109,8 +1075,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
channel_multiplier=channel_multiplier,
channel_divisor=8,
channel_min=None,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
**kwargs
)
return model
@ -1147,7 +1112,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
# NOTE: other models in the family didn't scale the feature count
num_features = _round_channels(1280, channel_multiplier, 8, None)
model = GenEfficientNet(
@ -1158,8 +1122,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
channel_divisor=8,
channel_min=None,
num_features=num_features,
bn_momentum=bn_momentum,
bn_eps=bn_eps,
bn_args=_resolve_bn_args(kwargs),
act_fn=swish,
**kwargs
)
@ -1205,20 +1168,6 @@ def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return mnasnet_100(pretrained, num_classes, in_chans, **kwargs)
@register_model
def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_mnasnet_100']
# these two args are for compat with tflite pretrained weights
kwargs['folded_bn'] = True
kwargs['padding_same'] = True
model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" MNASNet B1, depth multiplier of 1.4 """
@ -1269,20 +1218,6 @@ def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return semnasnet_100(pretrained, num_classes, in_chans, **kwargs)
@register_model
def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" MNASNet A1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_semnasnet_100']
# these two args are for compat with tflite pretrained weights
kwargs['folded_bn'] = True
kwargs['padding_same'] = True
model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """

Loading…
Cancel
Save