@@ -244,6 +244,161 @@ class InvertedResidual(nn.Module):
        return x


class XDepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having an IR block with an optional first pw conv.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        super(XDepthwiseSeparableConv, self).__init__()
        norm_kwargs = norm_kwargs or {}
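        # identity shortcut is only valid when spatial size (stride 1) and channel count are preserved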
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        conv_kwargs = {}
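        # Factorized depth-wise convolution: a 2x2 conv followed by 1xk and kx1 convs
        # stands in for the single kxk depth-wise conv of a standard DS block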
        self.conv_dw_2x2 = create_conv2d(
            in_chs, in_chs, 2, stride=stride, dilation=dilation,
            padding='same', depthwise=True, **conv_kwargs)
        self.conv_dw_1xk = create_conv2d(
            in_chs, in_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.conv_dw_kx1 = create_conv2d(
            in_chs, in_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
        else:
            self.se = None

        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        return self.conv_pw.in_channels

    def forward(self, x):
        residual = x
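
        # factorized depth-wise stack: 2x2, then 1xk, then kx1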
        x = self.conv_dw_2x2(x)
        x = self.conv_dw_1xk(x)
        x = self.conv_dw_kx1(x)
        x = self.bn1(x)
        x = self.act1(x)

        if self.se is not None:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual
        return x
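
    # Usage sketch (hypothetical values): XDepthwiseSeparableConv(32, 32, dw_kernel_size=3)
    # runs the 2x2 -> 1x3 -> 3x1 depth-wise stack, BN + act, then the 1x1 point-wise conv + BN,
    # and adds the input back since stride == 1 and in_chs == out_chs.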


class XInvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        super(XInvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
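        # expanded (hidden) width, rounded by make_divisible to the channel divisibility constraint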
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution, factorized into 2x2 + 1xk + kx1 convs
        self.conv_dw_2x2 = create_conv2d(
            mid_chs, mid_chs, 2, stride=stride, dilation=dilation,
            padding='same', depthwise=True, pad_shift=pad_shift, **conv_kwargs)
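        # NOTE: pad_shift presumably selects which side receives the extra 'same' padding for the even 2x2 kernel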
        self.conv_dw_1xk = create_conv2d(
            mid_chs, mid_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.conv_dw_kx1 = create_conv2d(
            mid_chs, mid_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw_2x2(x)
        x = self.conv_dw_1xk(x)
        x = self.conv_dw_kx1(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x
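
    # Usage sketch (hypothetical values): XInvertedResidual(16, 24, dw_kernel_size=3, exp_ratio=6.0, attn_layer='se')
    # expands 16 channels to roughly 96, runs the factorized depth-wise stack and SE, then projects to 24;
    # no shortcut is added since in_chs != out_chs.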


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing"""
