|
|
|
@ -88,6 +88,14 @@ default_cfgs = {
|
|
|
|
|
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
|
|
|
|
'efficientnet_b7': _cfg(
|
|
|
|
|
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
|
|
|
|
'efficientnet_es': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'efficientnet_em': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
|
|
|
|
'efficientnet_el': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
|
|
|
|
'tf_efficientnet_b0': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
|
|
|
|
input_size=(3, 224, 224)),
|
|
|
|
@ -112,6 +120,18 @@ default_cfgs = {
|
|
|
|
|
'tf_efficientnet_b7': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
|
|
|
|
|
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
|
|
|
|
'tf_efficientnet_es': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 224, 224), ),
|
|
|
|
|
'tf_efficientnet_em': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
|
|
|
|
'tf_efficientnet_el': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
|
|
|
|
'mixnet_s': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
|
|
|
|
|
'mixnet_m': _cfg(
|
|
|
|
@ -239,6 +259,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
|
|
|
|
act_fn = options['n'] if 'n' in options else None
|
|
|
|
|
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
|
|
|
|
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
|
|
|
|
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
|
|
|
|
|
|
|
|
|
num_repeat = int(options['r'])
|
|
|
|
|
# each type of block has different valid arguments, fill accordingly
|
|
|
|
@ -267,6 +288,19 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
|
|
|
|
pw_act=block_type == 'dsa',
|
|
|
|
|
noskip=block_type == 'dsa' or noskip,
|
|
|
|
|
)
|
|
|
|
|
elif block_type == 'er':
|
|
|
|
|
block_args = dict(
|
|
|
|
|
block_type=block_type,
|
|
|
|
|
exp_kernel_size=_parse_ksize(options['k']),
|
|
|
|
|
pw_kernel_size=pw_kernel_size,
|
|
|
|
|
out_chs=int(options['c']),
|
|
|
|
|
exp_ratio=float(options['e']),
|
|
|
|
|
fake_in_chs=fake_in_chs,
|
|
|
|
|
se_ratio=float(options['se']) if 'se' in options else None,
|
|
|
|
|
stride=int(options['s']),
|
|
|
|
|
act_fn=act_fn,
|
|
|
|
|
noskip=noskip,
|
|
|
|
|
)
|
|
|
|
|
elif block_type == 'cn':
|
|
|
|
|
block_args = dict(
|
|
|
|
|
block_type=block_type,
|
|
|
|
@ -356,6 +390,9 @@ class _BlockBuilder:
|
|
|
|
|
bt = ba.pop('block_type')
|
|
|
|
|
ba['in_chs'] = self.in_chs
|
|
|
|
|
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
|
|
|
|
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
|
|
|
|
# FIXME this is a hack to work around mismatch in origin impl input filters
|
|
|
|
|
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
|
|
|
|
ba['bn_args'] = self.bn_args
|
|
|
|
|
ba['pad_type'] = self.pad_type
|
|
|
|
|
# block act fn overrides the model default
|
|
|
|
@ -373,6 +410,13 @@ class _BlockBuilder:
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
|
|
|
|
|
block = DepthwiseSeparableConv(**ba)
|
|
|
|
|
elif bt == 'er':
|
|
|
|
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
|
|
|
|
ba['se_gate_fn'] = self.se_gate_fn
|
|
|
|
|
ba['se_reduce_mid'] = self.se_reduce_mid
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba)))
|
|
|
|
|
block = EdgeResidual(**ba)
|
|
|
|
|
elif bt == 'cn':
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
|
|
|
|
@ -519,10 +563,62 @@ class ConvBnAct(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EdgeResidual(nn.Module):
|
|
|
|
|
""" Residual block with expansion convolution followed by pointwise-linear w/ stride"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
|
|
|
|
stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1,
|
|
|
|
|
se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
|
|
|
|
|
bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
|
|
|
|
|
super(EdgeResidual, self).__init__()
|
|
|
|
|
mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
|
|
|
|
|
self.has_se = se_ratio is not None and se_ratio > 0.
|
|
|
|
|
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
|
|
|
|
self.act_fn = act_fn
|
|
|
|
|
self.drop_connect_rate = drop_connect_rate
|
|
|
|
|
|
|
|
|
|
# Expansion convolution
|
|
|
|
|
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
|
|
|
|
|
self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
|
|
|
|
|
|
|
|
|
|
# Squeeze-and-excitation
|
|
|
|
|
if self.has_se:
|
|
|
|
|
se_base_chs = mid_chs if se_reduce_mid else in_chs
|
|
|
|
|
self.se = SqueezeExcite(
|
|
|
|
|
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
|
|
|
|
|
|
|
|
|
|
# Point-wise linear projection
|
|
|
|
|
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
|
|
|
|
|
self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
|
# Expansion convolution
|
|
|
|
|
x = self.conv_exp(x)
|
|
|
|
|
x = self.bn1(x)
|
|
|
|
|
x = self.act_fn(x, inplace=True)
|
|
|
|
|
|
|
|
|
|
# Squeeze-and-excitation
|
|
|
|
|
if self.has_se:
|
|
|
|
|
x = self.se(x)
|
|
|
|
|
|
|
|
|
|
# Point-wise linear projection
|
|
|
|
|
x = self.conv_pwl(x)
|
|
|
|
|
x = self.bn2(x)
|
|
|
|
|
|
|
|
|
|
if self.has_residual:
|
|
|
|
|
if self.drop_connect_rate > 0.:
|
|
|
|
|
x = drop_connect(x, self.training, self.drop_connect_rate)
|
|
|
|
|
x += residual
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseSeparableConv(nn.Module):
|
|
|
|
|
""" DepthwiseSeparable block
|
|
|
|
|
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
|
|
|
|
|
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
|
|
|
|
|
factor of 1.0. This is an alternative to having a IR with an optional first pw conv.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
|
|
|
|
stride=1, pad_type='', act_fn=F.relu, noskip=False,
|
|
|
|
@ -1092,7 +1188,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
|
|
|
|
|
['ir_r4_k5_s2_e6_c192_se0.25'],
|
|
|
|
|
['ir_r1_k3_s1_e6_c320_se0.25'],
|
|
|
|
|
]
|
|
|
|
|
# NOTE: other models in the family didn't scale the feature count
|
|
|
|
|
num_features = _round_channels(1280, channel_multiplier, 8, None)
|
|
|
|
|
model = GenEfficientNet(
|
|
|
|
|
_decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
@ -1107,6 +1202,31 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
|
|
|
|
|
arch_def = [
|
|
|
|
|
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
|
|
|
|
|
# present in other models
|
|
|
|
|
['er_r1_k3_s1_e4_c24_fc24_noskip'],
|
|
|
|
|
['er_r2_k3_s2_e8_c32'],
|
|
|
|
|
['er_r4_k3_s2_e8_c48'],
|
|
|
|
|
['ir_r5_k5_s2_e8_c96'],
|
|
|
|
|
['ir_r4_k5_s1_e8_c144'],
|
|
|
|
|
['ir_r2_k5_s2_e8_c192'],
|
|
|
|
|
]
|
|
|
|
|
num_features = _round_channels(1280, channel_multiplier, 8, None)
|
|
|
|
|
model = GenEfficientNet(
|
|
|
|
|
_decode_arch_def(arch_def, depth_multiplier),
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|
stem_size=32,
|
|
|
|
|
channel_multiplier=channel_multiplier,
|
|
|
|
|
num_features=num_features,
|
|
|
|
|
bn_args=_resolve_bn_args(kwargs),
|
|
|
|
|
act_fn=F.relu,
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
|
|
|
|
"""Creates a MixNet Small model.
|
|
|
|
|
|
|
|
|
@ -1481,7 +1601,6 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-B6 """
|
|
|
|
@ -1512,6 +1631,45 @@ def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge Small. """
|
|
|
|
|
default_cfg = default_cfgs['efficientnet_es']
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.0, depth_multiplier=1.0,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge-Medium. """
|
|
|
|
|
default_cfg = default_cfgs['efficientnet_em']
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.0, depth_multiplier=1.1,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge-Large. """
|
|
|
|
|
default_cfg = default_cfgs['efficientnet_el']
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.2, depth_multiplier=1.4,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-B0. Tensorflow compatible variant """
|
|
|
|
@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge Small. Tensorflow compatible variant """
|
|
|
|
|
default_cfg = default_cfgs['tf_efficientnet_es']
|
|
|
|
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.0, depth_multiplier=1.0,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge-Medium. Tensorflow compatible variant """
|
|
|
|
|
default_cfg = default_cfgs['tf_efficientnet_em']
|
|
|
|
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.0, depth_multiplier=1.1,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
""" EfficientNet-Edge-Large. Tensorflow compatible variant """
|
|
|
|
|
default_cfg = default_cfgs['tf_efficientnet_el']
|
|
|
|
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
|
|
|
|
kwargs['pad_type'] = 'same'
|
|
|
|
|
model = _gen_efficientnet_edge(
|
|
|
|
|
channel_multiplier=1.2, depth_multiplier=1.4,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Creates a MixNet Small model.
|
|
|
|
|