Experimenting with XSepConv w/ EfficientNet and MobileNetV3...

xsepconv
Ross Wightman 4 years ago
parent 8bd08d5f64
commit 9419577c90

@@ -102,6 +102,8 @@ default_cfgs = {
'efficientnet_eca_b2': _cfg(
url='',
input_size=(3, 260, 260), pool_size=(9, 9)),
'xefficientnet_b0': _cfg(
url=''),
'efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
'efficientnet_em': _cfg(
@@ -242,7 +244,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
}
_DEBUG = False
_DEBUG = True
class EfficientNet(nn.Module):
@@ -658,6 +660,54 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
return model
def _gen_xefficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r1_k5_s2_e6_c40', 'xir_r1_k5_s1_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['xir_r3_k5_s1_e6_c112'],
['ir_r1_k5_s2_e6_c192', 'xir_r3_k5_s1_e6_c192'],
['xir_r1_k5_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=Swish,
attn_layer='sev2',
attn_kwargs=dict(se_ratio=0.25),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-EdgeTPU model
@@ -1064,6 +1114,15 @@ def efficientnet_eca_b2(pretrained=False, **kwargs):
return model
@register_model
def xefficientnet_b0(pretrained=False, **kwargs):
""" XEfficientNet-B0 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_xefficientnet(
'xefficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. """

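For reference, a minimal usage sketch for the new variant. This assumes the experimental xsepconv branch is installed in place of a released timm build; 'xefficientnet_b0' ships no pretrained weights in this commit, so pretrained must stay False.

import torch
import timm

# Hypothetical usage sketch; drop rates follow the NOTE in xefficientnet_b0 above.
model = timm.create_model('xefficientnet_b0', pretrained=False,
                          drop_rate=0.2, drop_path_rate=0.2)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # B0 default 224x224 input
print(logits.shape)  # torch.Size([1, 1000])
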
@@ -244,6 +244,161 @@ class InvertedResidual(nn.Module):
return x
class XDepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
(factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., pad_shift=0):
super(XDepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
conv_kwargs = {}
self.conv_dw_2x2 = create_conv2d(
in_chs, in_chs, 2, stride=stride, dilation=dilation,
padding='same', depthwise=True, pad_shift=pad_shift, **conv_kwargs)
self.conv_dw_1xk = create_conv2d(
in_chs, in_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.conv_dw_kx1 = create_conv2d(
in_chs, in_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Attention block (Squeeze-Excitation, ECA, etc)
if attn_layer is not None:
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
else:
self.se = None
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
def feature_module(self, location):
# no expansion in this block, pre pw only feature extraction point
return 'conv_pw'
def feature_channels(self, location):
return self.conv_pw.in_channels
def forward(self, x):
residual = x
x = self.conv_dw_2x2(x)
x = self.conv_dw_1xk(x)
x = self.conv_dw_kx1(x)
x = self.bn1(x)
x = self.act1(x)
if self.se is not None:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
class XInvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE and CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0,
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.):
super(XInvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw_2x2 = create_conv2d(
mid_chs, mid_chs, 2, stride=stride, dilation=dilation,
padding='same', depthwise=True, pad_shift=pad_shift, **conv_kwargs)
self.conv_dw_1xk = create_conv2d(
mid_chs, mid_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.conv_dw_kx1 = create_conv2d(
mid_chs, mid_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Attention block (Squeeze-Excitation, ECA, etc)
if attn_layer is not None:
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_module(self, location):
if location == 'post_exp':
return 'act1'
return 'conv_pwl'
def feature_channels(self, location):
if location == 'post_exp':
return self.conv_pw.out_channels
# location == 'pre_pw'
return self.conv_pwl.in_channels
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw_2x2(x)
x = self.conv_dw_1xk(x)
x = self.conv_dw_kx1(x)
x = self.bn2(x)
x = self.act2(x)
# Attention
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""

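The core of both X blocks above is the XSepConv-style depthwise stack: a 2x2 depthwise conv with corner-shifted 'SAME' padding, followed by 1xk and kx1 depthwise convs, in place of a single kxk depthwise conv. A standalone sketch under those assumptions, using plain nn.Conv2d instead of create_conv2d and fixing the 2x2 padding to the shift=0 ('ul') case:

# Rough standalone sketch of the XSepConv depthwise stack used by
# XDepthwiseSeparableConv / XInvertedResidual above; plain nn.Conv2d stands in
# for create_conv2d and the pad_shift rotation is ignored (shift=0 only).
import torch
import torch.nn as nn

class XSepConvDW(nn.Module):
    def __init__(self, chs, k=5):
        super().__init__()
        # 2x2 depthwise conv; asymmetric one-pixel padding keeps spatial size
        self.dw_2x2 = nn.Sequential(
            nn.ZeroPad2d((0, 1, 0, 1)),  # pad right/bottom, i.e. the shift=0 ('ul') branch
            nn.Conv2d(chs, chs, 2, groups=chs, bias=False))
        # kxk depthwise factored into 1xk and kx1 depthwise convs
        self.dw_1xk = nn.Conv2d(chs, chs, (1, k), padding=(0, k // 2), groups=chs, bias=False)
        self.dw_kx1 = nn.Conv2d(chs, chs, (k, 1), padding=(k // 2, 0), groups=chs, bias=False)

    def forward(self, x):
        return self.dw_kx1(self.dw_1xk(self.dw_2x2(x)))

x = torch.randn(1, 40, 28, 28)
print(XSepConvDW(40, k=5)(x).shape)  # torch.Size([1, 40, 28, 28])
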
@@ -90,7 +90,7 @@ def _decode_block_str(block_str):
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
if block_type == 'ir' or block_type == 'xir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
@@ -106,7 +106,7 @@ def _decode_block_str(block_str):
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
elif block_type == 'ds' or block_type == 'dsa' or block_type == 'xds':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
@@ -232,6 +232,7 @@ class EfficientNetBuilder:
# state updated during build, consumed by model
self.in_chs = None
self.x_count = 0
self.features = OrderedDict()
def _round_channels(self, chs):
@@ -261,33 +262,35 @@ class EfficientNetBuilder:
ba['attn_kwargs'] = self.attn_kwargs
else:
ba['attn_kwargs'].update(self.attn_kwargs)
ba['drop_path_rate'] = drop_path_rate
if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate
if self.verbose:
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'xir':
ba['pad_shift'] = self.x_count
block = XInvertedResidual(**ba)
self.x_count = (self.x_count + 1) % 4
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'xds':
ba['pad_shift'] = self.x_count
block = XDepthwiseSeparableConv(**ba)
self.x_count = (self.x_count + 1) % 4
elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate
if self.verbose:
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
block = EdgeResidual(**ba)
elif bt == 'cn':
if self.verbose:
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
del ba['drop_path_rate']
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
if self.verbose:
logging.info(' {} {}, Args: {}'.format(block.__class__.__name__, block_idx, str(ba)))
return block
def __call__(self, in_chs, model_block_args):

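The 'xir'/'xds' strings decode exactly like their 'ir'/'ds' counterparts apart from block_type; the builder then hands each X block it constructs the current x_count as pad_shift and wraps it modulo 4. A toy sketch of that rotation (not the real builder):

# Toy sketch of the pad_shift rotation for X blocks: each built X block gets the
# current x_count, and x_count wraps modulo 4, so the extra pixel of 'SAME'
# padding on the 2x2 depthwise conv rotates through the four corners.
CORNERS = {0: 'ul', 1: 'lr', 2: 'ur', 3: 'll'}  # matches the pad_same() branches further down

x_count = 0
for block_idx in range(6):  # e.g. six consecutive xir blocks after repeat expansion
    pad_shift = x_count
    x_count = (x_count + 1) % 4
    print(f'x block {block_idx}: pad_shift={pad_shift} ({CORNERS[pad_shift]})')
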
@@ -13,8 +13,8 @@ from .padding import get_padding, pad_same, is_static_pad
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
x = pad_same(x, weight.shape[-2:], stride, dilation)
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1, pad_shift: int = 0):
x = pad_same(x, weight.shape[-2:], stride, dilation, pad_shift)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
@@ -23,12 +23,14 @@ class Conv2dSame(nn.Conv2d):
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
padding=0, dilation=1, groups=1, bias=True, pad_shift=0):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.pad_shift = pad_shift
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return conv2d_same(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.pad_shift)
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:

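A rough sanity check of the plumbing, assuming this branch's Conv2dSame and the timm.models.layers export path (which may differ in other versions): every shift yields the same output shape for a 2x2 depthwise conv, only the corner absorbing the odd pixel of padding moves.

# Sketch only; import path assumed for this branch.
import torch
from timm.models.layers import Conv2dSame

x = torch.randn(1, 8, 14, 14)
for shift in range(4):
    conv = Conv2dSame(8, 8, kernel_size=2, stride=1, groups=8, bias=False, pad_shift=shift)
    print(shift, conv(x).shape)  # torch.Size([1, 8, 14, 14]) for every shift
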
@@ -7,7 +7,7 @@ from torch._six import container_abcs
# From PyTorch internals
def _ntuple(n):
def ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
@@ -15,10 +15,10 @@ def _ntuple(n):
return parse
tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
tup_single = ntuple(1)
tup_pair = ntuple(2)
tup_triple = ntuple(3)
tup_quadruple = ntuple(4)
def make_divisible(v, divisor=8, min_value=None):

@@ -4,12 +4,17 @@ Hacked together by Ross Wightman
"""
import math
from typing import List
from .helpers import ntuple
import torch.nn.functional as F
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
if isinstance(kernel_size, (list, tuple)):
stride = ntuple(len(kernel_size))(stride)
dilation = ntuple(len(kernel_size))(dilation)
return [get_padding(k, s, d) for k, s, d in zip(kernel_size, stride, dilation)]
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
@@ -25,9 +30,17 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), shift: int = 0):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
if shift == 0:
pl = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # ul
elif shift == 1:
pl = [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2] # lr
elif shift == 2:
pl = [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # ur
else:
pl = [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2] # ll
x = F.pad(x, pl)
return x

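Worked example of the four branches above for a 2x2 kernel at stride 1: total padding is one pixel per axis, so pad // 2 == 0 and pad - pad // 2 == 1, and the shift only picks which side of each axis gets that pixel (F.pad takes [left, right, top, bottom]).

# Pad lists produced by each shift branch when pad_h == pad_w == 1.
pad_h = pad_w = 1
shifts = {
    0: [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],  # [0, 1, 0, 1] ul
    1: [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2],  # [1, 0, 1, 0] lr
    2: [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2],  # [1, 0, 0, 1] ur
    3: [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2],  # [0, 1, 1, 0] ll
}
for s, pl in shifts.items():
    print(s, pl)
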
@@ -35,6 +35,7 @@ default_cfgs = {
'mobilenetv3_small_075': _cfg(url='', interpolation='bicubic'),
'mobilenetv3_small_100': _cfg(url='', interpolation='bicubic'),
'mobilenetv3_eca_large': _cfg(url='', interpolation='bicubic'),
'xmobilenetv3_large_100': _cfg(url='', interpolation='bicubic'),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
@@ -58,7 +59,7 @@ default_cfgs = {
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
}
_DEBUG = False
_DEBUG = True
class MobileNetV3(nn.Module):
@@ -360,6 +361,102 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
return model
def _gen_xmobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = nn.ReLU
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
# stage 2, 28x28 in
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
# stage 3, 14x14 in
['ir_r2_k3_s1_e3_c48'],
# stage 4, 14x14in
['ir_r3_k3_s2_e6_c96'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'],
]
else:
act_layer = HardSwish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
if 'minimal' in variant:
act_layer = nn.ReLU
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k3_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112'],
# stage 5, 14x14in
['ir_r3_k3_s2_e6_c160'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'],
]
else:
act_layer = HardSwish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre', 'xir_r2_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['xir_r2_k5_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r1_k5_s2_e6_c160_se0.25', 'xir_r2_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
attn_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisible_by=8),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
@@ -471,6 +568,13 @@ def mobilenetv3_eca_large(pretrained=False, **kwargs):
return model
@register_model
def xmobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_xmobilenet_v3('xmobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet V3 """

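The same registry pattern applies to the MobileNet-V3 variant. A sketch for instantiating it and comparing parameter counts against the stock large_100 entrypoint, again assuming this experimental branch and no pretrained weights:

# Sketch only, assuming the xsepconv branch is installed and that the stock
# 'mobilenetv3_large_100' entrypoint is present (it is in mainline timm of this era).
import torch
import timm

xm = timm.create_model('xmobilenetv3_large_100', pretrained=False)
base = timm.create_model('mobilenetv3_large_100', pretrained=False)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print('xmobilenetv3_large_100:', n_params(xm))
print('mobilenetv3_large_100 :', n_params(base))

with torch.no_grad():
    print(xm(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
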