Finally got around to adding EdgeTPU EfficientNet variant

pull/30/head
Ross Wightman 5 years ago
parent daeaa113e2
commit 9ec6824bab

@@ -88,6 +88,14 @@ default_cfgs = {
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'efficientnet_b7': _cfg(
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'efficientnet_es': _cfg(
url=''),
'efficientnet_em': _cfg(
url='',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_el': _cfg(
url='',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)),
@@ -112,6 +120,18 @@ default_cfgs = {
'tf_efficientnet_b7': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224)),
'tf_efficientnet_em': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_el': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'mixnet_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
'mixnet_m': _cfg(
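
The tf_efficientnet_e* cfgs above switch to Inception-style normalization (mean/std 0.5) and set a crop_pct for eval resizing. A hedged preprocessing sketch matching the tf_efficientnet_em cfg (torchvision transforms assumed):

    from torchvision import transforms

    size, crop_pct = 240, 0.882
    preprocess = transforms.Compose([
        transforms.Resize(int(size / crop_pct)),  # shorter side -> 272
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])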
@@ -239,6 +259,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
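
`_decode_block_str` parses entries like `er_r1_k3_s1_e4_c24_fc24_noskip` from the arch defs further down. A hedged, standalone re-implementation covering only the `er` options used here (the real parser also handles mixed kernel sizes via _parse_ksize):

    def decode_er(block_str):
        # Illustrative only; not the parser from this file.
        parts = block_str.split('_')
        options, noskip = {}, False
        for p in parts[1:]:
            if p == 'noskip':
                noskip = True
            else:
                key = 'fc' if p.startswith('fc') else p[0]
                options[key] = p[len(key):]
        return dict(
            block_type=parts[0],
            num_repeat=int(options['r']),
            exp_kernel_size=int(options['k']),
            stride=int(options['s']),
            exp_ratio=float(options['e']),
            out_chs=int(options['c']),
            fake_in_chs=int(options.get('fc', 0)),
            noskip=noskip,
        )

    print(decode_er('er_r1_k3_s1_e4_c24_fc24_noskip'))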
@@ -267,6 +288,19 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
@@ -356,6 +390,9 @@ class _BlockBuilder:
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around a mismatch in the original impl's input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['bn_args'] = self.bn_args
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
@@ -373,6 +410,13 @@ class _BlockBuilder:
if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid
if self.verbose:
logging.info(' EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba)))
block = EdgeResidual(**ba)
elif bt == 'cn':
if self.verbose:
logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
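
The drop_connect_rate set per block above ramps linearly with depth (base rate × block_idx / block_count), so later blocks are dropped more often. A hedged sketch of the batchwise stochastic-depth form this file applies in the residual paths:

    import torch

    def drop_connect_sketch(x, training, drop_rate):
        # Keep each sample's residual branch with prob (1 - drop_rate);
        # rescale survivors so the expected activation is unchanged.
        if not training or drop_rate == 0.:
            return x
        keep_prob = 1. - drop_rate
        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
        return x / keep_prob * mask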
@@ -519,10 +563,62 @@ class ConvBnAct(nn.Module):
return x
class EdgeResidual(nn.Module):
""" Residual block with expansion convolution followed by pointwise-linear w/ stride"""
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1,
se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
super(EdgeResidual, self).__init__()
mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
# Expansion convolution
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
# Squeeze-and-excitation
if self.has_se:
se_base_chs = mid_chs if se_reduce_mid else in_chs
self.se = SqueezeExcite(
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
def forward(self, x):
residual = x
# Expansion convolution
x = self.conv_exp(x)
x = self.bn1(x)
x = self.act_fn(x, inplace=True)
# Squeeze-and-excitation
if self.has_se:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class DepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
factor of 1.0. This is an alternative to having an IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=F.relu, noskip=False,
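
The EdgeResidual block above replaces the inverted residual's 1x1 expand + kxk depthwise pair with a single full kxk expansion conv (a better fit for EdgeTPU hardware) and puts the stride on the 1x1 projection. A dependency-free condensed sketch (SE, 'same' padding, and drop connect omitted; TinyEdgeResidual is illustrative only):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyEdgeResidual(nn.Module):
        def __init__(self, in_chs, out_chs, exp_ratio=4.0, stride=1):
            super().__init__()
            mid_chs = int(in_chs * exp_ratio)
            self.conv_exp = nn.Conv2d(in_chs, mid_chs, 3, padding=1, bias=False)  # full kxk expansion
            self.bn1 = nn.BatchNorm2d(mid_chs)
            self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, stride=stride, bias=False)  # strided projection
            self.bn2 = nn.BatchNorm2d(out_chs)
            self.has_residual = in_chs == out_chs and stride == 1

        def forward(self, x):
            residual = x
            x = F.relu(self.bn1(self.conv_exp(x)), inplace=True)
            x = self.bn2(self.conv_pwl(x))
            if self.has_residual:
                x = x + residual
            return x

    print(TinyEdgeResidual(32, 32)(torch.randn(1, 32, 56, 56)).shape)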
@@ -1092,7 +1188,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
# NOTE: other models in the family didn't scale the feature count
num_features = _round_channels(1280, channel_multiplier, 8, None)
model = GenEfficientNet(
_decode_arch_def(arch_def, depth_multiplier),
@@ -1107,6 +1202,31 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
return model
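
Both generators round the 1280 head features with `_round_channels(1280, channel_multiplier, 8, None)`. A hedged restatement of the standard EfficientNet rounding rule (round to the divisor, never shrinking by more than 10%):

    def round_channels_sketch(chs, multiplier=1.0, divisor=8):
        # Illustrative; mirrors the usual round_filters behavior.
        chs *= multiplier
        new_chs = max(divisor, int(chs + divisor / 2) // divisor * divisor)
        if new_chs < 0.9 * chs:
            new_chs += divisor
        return int(new_chs)

    print(round_channels_sketch(1280, 1.2))  # 1536 head features at multiplier 1.2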
def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
arch_def = [
# NOTE: `fc` (fake in_chs) works around a mismatch between the stem channels and the
# first block's input filters in the original EdgeTPU def; no other model needs it
['er_r1_k3_s1_e4_c24_fc24_noskip'],
['er_r2_k3_s2_e8_c32'],
['er_r4_k3_s2_e8_c48'],
['ir_r5_k5_s2_e8_c96'],
['ir_r4_k5_s1_e8_c144'],
['ir_r2_k5_s2_e8_c192'],
]
num_features = _round_channels(1280, channel_multiplier, 8, None)
model = GenEfficientNet(
_decode_arch_def(arch_def, depth_multiplier),
num_classes=num_classes,
stem_size=32,
channel_multiplier=channel_multiplier,
num_features=num_features,
bn_args=_resolve_bn_args(kwargs),
act_fn=F.relu,
**kwargs
)
return model
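
Worked example of why `fc24` matters in the first edge block: the stem outputs 32 channels, but the original EdgeTPU def derives the expansion width from 24 input filters, so the checkpoint expects 24 × 4 = 96 mid channels rather than 32 × 4 = 128 (matching the mid_chs logic in EdgeResidual above):

    stem_chs, fake_in_chs, exp_ratio = 32, 24, 4
    print(stem_chs * exp_ratio)     # 128, what in_chs alone would give
    print(fake_in_chs * exp_ratio)  # 96, what the TPU weights actually expect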
def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates a MixNet Small model.
@@ -1481,7 +1601,6 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B6 """
@@ -1512,6 +1631,45 @@ def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge Small. """
default_cfg = default_cfgs['efficientnet_es']
model = _gen_efficientnet_edge(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge-Medium. """
default_cfg = default_cfgs['efficientnet_em']
model = _gen_efficientnet_edge(
channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge-Large. """
default_cfg = default_cfgs['efficientnet_el']
model = _gen_efficientnet_edge(
channel_multiplier=1.2, depth_multiplier=1.4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
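
For reference, the edge-variant scaling wired up above, with eval input sizes from the default cfgs (plain summary dict, illustrative only):

    edge_scaling = {
        'efficientnet_es': dict(channel_multiplier=1.0, depth_multiplier=1.0, input_size=224),
        'efficientnet_em': dict(channel_multiplier=1.0, depth_multiplier=1.1, input_size=240),
        'efficientnet_el': dict(channel_multiplier=1.2, depth_multiplier=1.4, input_size=300),
    }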
@register_model
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """
@@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
return model
@register_model
def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge Small. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_es']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge-Medium. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_em']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-Edge-Large. Tensorflow compatible variant """
default_cfg = default_cfgs['tf_efficientnet_el']
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
channel_multiplier=1.2, depth_multiplier=1.4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
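
The tf_ entry points above differ from the PyTorch-native ones only in two kwargs: TF's BatchNorm epsilon and TF-style 'same' padding (resolved by select_conv2d so spatial dims match the original graphs). A hedged illustration of the BN difference, assuming _BN_EPS_TF_DEFAULT is TF's 1e-3:

    import torch.nn as nn

    pt_bn = nn.BatchNorm2d(32)            # PyTorch default eps=1e-5
    tf_bn = nn.BatchNorm2d(32, eps=1e-3)  # matches TF-trained checkpoints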
@register_model
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Creates a MixNet Small model.
