You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
521 lines
19 KiB
521 lines
19 KiB
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
from .layers import create_conv2d, create_attn, drop_path
|
|
|
|
|
|
# Batch-norm defaults used for Google/TensorFlow training of mobile networks with
# RMSprop, as per the papers and the TF reference implementations.
# The PyTorch momentum equivalent of a TF decay value is (1 - TF decay).
# NOTE: the decay varies between .99 and .9997 depending on the source:
#   .99   in the official TF TPU implementation
#   .9997 (with .999 in the search space) in the paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = {'momentum': BN_MOMENTUM_TF_DEFAULT, 'eps': BN_EPS_TF_DEFAULT}
|
|
|
|
|
|
def get_bn_args_tf():
    """Return a new dict holding the TF-default batch-norm kwargs.

    A copy is returned so callers may modify the result without touching the
    module-level ``_BN_ARGS_TF`` defaults.
    """
    return dict(_BN_ARGS_TF)
|
|
|
|
|
|
def resolve_bn_args(kwargs):
    """Extract batch-norm options from ``kwargs`` and return norm-layer kwargs.

    The keys ``bn_tf``, ``bn_momentum`` and ``bn_eps`` are popped from
    ``kwargs`` (the dict is modified in place).

    Args:
        kwargs: dict of model options; consumed keys are removed from it.

    Returns:
        A dict that starts from the TF defaults when ``bn_tf`` is truthy
        (empty otherwise) and is then overridden by ``bn_momentum`` ->
        ``'momentum'`` and ``bn_eps`` -> ``'eps'`` when those are not None.
    """
    use_tf_defaults = kwargs.pop('bn_tf', False)
    bn_args = get_bn_args_tf() if use_tf_defaults else {}
    # (option name in kwargs, name expected by the norm layer)
    for option, target in (('bn_momentum', 'momentum'), ('bn_eps', 'eps')):
        override = kwargs.pop(option, None)
        if override is not None:
            bn_args[target] = override
    return bn_args
|
|
|
|
|
|
def resolve_attn_args(layer, kwargs, in_chs, act_layer=None):
    """Build the kwargs for an attention layer inside a block.

    Args:
        layer: attention layer spec. Either a string name (``'sev2'`` selects
            the SqueezeExciteV2 handling), an ``nn.Module`` subclass, or an
            ``nn.Module`` instance.
        kwargs: user supplied attention kwargs, or None. Never modified; a
            copy is made.
        in_chs: input channel count of the containing block.
        act_layer: activation of the containing block, used as fallback when
            the attention kwargs do not define ``act_layer``.

    Returns:
        A new dict of kwargs for the attention layer.

    Raises:
        AssertionError: for an SE layer when neither the attention kwargs nor
            the ``act_layer`` argument provide an activation.
    """
    attn_kwargs = kwargs.copy() if kwargs is not None else {}

    if isinstance(layer, str):
        is_se = layer == 'sev2'
    else:
        # A layer is normally passed as a class, for which isinstance(layer, nn.Module)
        # is False; a Module *instance* has no __name__. Resolve the class in both
        # cases so the name check works for either form.
        layer_cls = layer if isinstance(layer, type) else type(layer)
        is_se = 'SqueezeExciteV2' in layer_cls.__name__

    if is_se:
        # some models, like MobileNetV3, calculate SE reduction chs from the
        # containing block's mid_ch instead of in_ch
        if not attn_kwargs.pop('reduce_mid', False):
            attn_kwargs['reduced_base_chs'] = in_chs
        # if act_layer is not defined by the attn kwargs, the containing block's
        # act_layer will be used for attn
        if attn_kwargs.get('act_layer', None) is None:
            assert act_layer is not None, 'act_layer is required for SE attention'
            attn_kwargs['act_layer'] = act_layer
    return attn_kwargs
|
|
|
|
|
|
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a multiple of ``divisor``.

    Args:
        v: value to round.
        divisor: the result is a multiple of this (before the lower bound
            is applied).
        min_value: lower bound for the result; a falsy value means
            ``divisor`` is used as the bound.

    Returns:
        The rounded value, bumped up by one ``divisor`` step when rounding
        would otherwise leave it below 90% of ``v``.
    """
    lower_bound = min_value if min_value else divisor
    rounded = int(v + divisor / 2) // divisor * divisor
    if lower_bound >= rounded:
        rounded = lower_bound
    # Make sure that rounding down does not go down by more than 10%.
    if rounded < 0.9 * v:
        rounded += divisor
    return rounded
|
|
|
|
|
|
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier.

    A falsy ``multiplier`` returns ``channels`` unchanged; otherwise the
    scaled channel count is rounded by ``make_divisible`` using ``divisor``
    and ``channel_min``.
    """
    if multiplier:
        return make_divisible(channels * multiplier, divisor, channel_min)
    return channels
|
|
|
|
|
|
class ChannelShuffle(nn.Module):
    """Shuffle the channels of a 4-D input across ``groups`` groups.

    FIXME haven't used yet
    """

    def __init__(self, groups):
        """Store the number of channel groups used by ``forward``."""
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]

        Raises:
            AssertionError: if the channel count is not divisible by ``groups``.
        """
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        # Floor division keeps the per-group channel count an exact integer
        # instead of going through float division and int().
        return (
            x.view(N, g, C // g, H, W)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(N, C, H, W)
        )
|
|
|
|
|
|
class ConvBnAct(nn.Module):
    """Convolution followed by a norm layer and an in-place activation."""

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        """Create the conv (via ``create_conv2d``), norm and activation layers.

        ``norm_kwargs`` is forwarded to ``norm_layer``; ``pad_type`` is passed
        to ``create_conv2d`` as its ``padding`` argument.
        """
        super().__init__()
        if not norm_kwargs:
            norm_kwargs = {}
        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size,
            stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

    def feature_module(self, location):
        """Return the submodule name used for feature extraction (``location`` is ignored)."""
        return 'act1'

    def feature_channels(self, location):
        """Return the conv's output channel count (``location`` is ignored)."""
        return self.conv.out_channels

    def forward(self, x):
        """Apply conv -> norm -> activation."""
        return self.act1(self.bn1(self.conv(x)))
|
|
|
|
|
|
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable convolution block.

    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        """Build depth-wise conv, optional attention and point-wise conv.

        A residual connection is used only when ``stride == 1``,
        ``in_chs == out_chs`` and ``noskip`` is false. ``pw_act`` enables an
        activation after the point-wise conv (identity otherwise).
        """
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs else {}
        same_shape = stride == 1 and in_chs == out_chs
        self.has_residual = same_shape and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is None:
            self.se = None
        else:
            se_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, in_chs, **se_kwargs)

        # Point-wise convolution
        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        """Return the feature extraction point name (``location`` is ignored)."""
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        """Return the point-wise conv's input channel count (``location`` is ignored)."""
        return self.conv_pw.in_channels

    def forward(self, x):
        """Run dw conv -> optional attention -> pw conv, adding the input when residual."""
        shortcut = x

        x = self.act1(self.bn1(self.conv_dw(x)))

        if self.se is not None:
            x = self.se(x)

        x = self.act2(self.bn2(self.conv_pw(x)))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|
|
|
|
|
|
class InvertedResidual(nn.Module):
    """Inverted residual block w/ optional SE and CondConv routing.

    Layout: point-wise expansion -> depth-wise conv -> optional attention ->
    point-wise linear projection, with a residual add when shapes allow.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        """Build the block.

        The expanded width is ``make_divisible(in_chs * exp_ratio)``.
        ``conv_kwargs`` is forwarded to every ``create_conv2d`` call and
        ``norm_kwargs`` to every norm layer. A residual connection is used
        only when ``in_chs == out_chs``, ``stride == 1`` and ``noskip`` is false.
        """
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs else {}
        conv_kwargs = conv_kwargs if conv_kwargs else {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        same_shape = in_chs == out_chs and stride == 1
        self.has_residual = same_shape and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(
            in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type,
            depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is None:
            self.se = None
        else:
            se_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **se_kwargs)

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        """Return ``'act1'`` for ``'post_exp'``, otherwise ``'conv_pwl'``."""
        return 'act1' if location == 'post_exp' else 'conv_pwl'

    def feature_channels(self, location):
        """Return the channel count at the given feature location."""
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        """Run expansion -> depth-wise -> attention -> projection (+ residual)."""
        shortcut = x

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x)))

        # Depth-wise convolution
        x = self.act2(self.bn2(self.conv_dw(x)))

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.bn3(self.conv_pwl(x))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|
|
|
|
|
|
class XDepthwiseSeparableConv(nn.Module):
    """Depthwise-separable block with a three-step depth-wise stage.

    Same layout as a depthwise-separable block (no expansion), except that the
    depth-wise stage is three depth-wise convs applied one after another:
    a 2x2 conv, a 1xk conv and a kx1 conv, with k = ``dw_kernel_size``.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        """Build the depth-wise convs, optional attention and point-wise conv.

        A residual connection is used only when ``stride == 1``,
        ``in_chs == out_chs`` and ``noskip`` is false. ``pw_act`` enables an
        activation after the point-wise conv (identity otherwise).
        """
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs else {}
        same_shape = stride == 1 and in_chs == out_chs
        self.has_residual = same_shape and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        # NOTE(review): `stride` is handed to all three sequential depth-wise convs;
        # presumably each one then down-samples, compounding the stride when
        # stride > 1 -- confirm this is intended.
        self.conv_dw_2x2 = create_conv2d(
            in_chs, in_chs, 2,
            stride=stride, dilation=dilation, padding='same', depthwise=True)
        self.conv_dw_1xk = create_conv2d(
            in_chs, in_chs, (1, dw_kernel_size),
            stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.conv_dw_kx1 = create_conv2d(
            in_chs, in_chs, (dw_kernel_size, 1),
            stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is None:
            self.se = None
        else:
            se_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, in_chs, **se_kwargs)

        # Point-wise convolution
        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        """Return the feature extraction point name (``location`` is ignored)."""
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        """Return the point-wise conv's input channel count (``location`` is ignored)."""
        return self.conv_pw.in_channels

    def forward(self, x):
        """Run the three dw convs -> optional attention -> pw conv (+ residual)."""
        shortcut = x

        # Depth-wise stage: 2x2, then 1xk, then kx1
        x = self.conv_dw_kx1(self.conv_dw_1xk(self.conv_dw_2x2(x)))
        x = self.act1(self.bn1(x))

        if self.se is not None:
            x = self.se(x)

        x = self.act2(self.bn2(self.conv_pw(x)))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|
|
|
|
|
|
class XInvertedResidual(nn.Module):
    """Inverted residual block w/ optional SE and a three-step depth-wise stage.

    Layout: point-wise expansion -> depth-wise 2x2, 1xk and kx1 convs applied
    in sequence (k = ``dw_kernel_size``) -> optional attention -> point-wise
    linear projection, with a residual add when shapes allow.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        """Build the block.

        The expanded width is ``make_divisible(in_chs * exp_ratio)``.
        ``conv_kwargs`` is forwarded to every ``create_conv2d`` call;
        ``pad_shift`` is forwarded only to the 2x2 depth-wise conv. A residual
        connection is used only when ``in_chs == out_chs``, ``stride == 1``
        and ``noskip`` is false.
        """
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs else {}
        conv_kwargs = conv_kwargs if conv_kwargs else {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        same_shape = in_chs == out_chs and stride == 1
        self.has_residual = same_shape and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(
            in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        # NOTE(review): `stride` is handed to all three sequential depth-wise convs;
        # presumably each one then down-samples, compounding the stride when
        # stride > 1 -- confirm this is intended.
        self.conv_dw_2x2 = create_conv2d(
            mid_chs, mid_chs, 2,
            stride=stride, dilation=dilation, padding='same',
            depthwise=True, pad_shift=pad_shift, **conv_kwargs)
        self.conv_dw_1xk = create_conv2d(
            mid_chs, mid_chs, (1, dw_kernel_size),
            stride=stride, dilation=dilation, padding=pad_type,
            depthwise=True, **conv_kwargs)
        self.conv_dw_kx1 = create_conv2d(
            mid_chs, mid_chs, (dw_kernel_size, 1),
            stride=stride, dilation=dilation, padding=pad_type,
            depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is None:
            self.se = None
        else:
            se_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **se_kwargs)

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        """Return ``'act1'`` for ``'post_exp'``, otherwise ``'conv_pwl'``."""
        return 'act1' if location == 'post_exp' else 'conv_pwl'

    def feature_channels(self, location):
        """Return the channel count at the given feature location."""
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        """Run expansion -> three dw convs -> attention -> projection (+ residual)."""
        shortcut = x

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x)))

        # Depth-wise convolution: 2x2, then 1xk, then kx1
        x = self.conv_dw_kx1(self.conv_dw_1xk(self.conv_dw_2x2(x)))
        x = self.act2(self.bn2(x))

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.bn3(self.conv_pwl(x))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|
|
|
|
|
|
class CondConvResidual(InvertedResidual):
    """Inverted residual block w/ CondConv routing.

    Reuses the layers built by ``InvertedResidual`` (created with
    ``num_experts`` passed through ``conv_kwargs``) and adds a linear routing
    function whose sigmoid output is passed to each conv call in ``forward``.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 num_experts=0, drop_path_rate=0.):
        """Build the parent block with expert-aware convs, then the routing layer."""
        self.num_experts = num_experts
        expert_conv_kwargs = {'num_experts': self.num_experts}

        super().__init__(
            in_chs, out_chs,
            dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation,
            pad_type=pad_type, act_layer=act_layer, noskip=noskip,
            exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            attn_layer=attn_layer, attn_kwargs=attn_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs,
            conv_kwargs=expert_conv_kwargs,
            drop_path_rate=drop_path_rate)

        # Maps pooled input features (in_chs values per sample) to one weight per expert.
        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        """Compute routing weights from the input, then run the routed block."""
        shortcut = x

        # CondConv routing
        pooled = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled))

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x, routing_weights)))

        # Depth-wise convolution
        x = self.act2(self.bn2(self.conv_dw(x, routing_weights)))

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.bn3(self.conv_pwl(x, routing_weights))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|
|
|
|
|
|
class EdgeResidual(nn.Module):
    """Residual block with expansion convolution followed by pointwise-linear w/ stride.

    Layout: expansion conv -> optional attention -> point-wise linear
    projection (which carries the stride and dilation), with a residual add
    when shapes allow.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 drop_path_rate=0.):
        """Build the block.

        The expanded width is ``make_divisible(base * exp_ratio)`` where
        ``base`` is ``fake_in_chs`` when it is positive and ``in_chs``
        otherwise. A residual connection is used only when
        ``in_chs == out_chs``, ``stride == 1`` and ``noskip`` is false.
        """
        super().__init__()
        norm_kwargs = norm_kwargs if norm_kwargs else {}
        base_chs = fake_in_chs if fake_in_chs > 0 else in_chs
        mid_chs = make_divisible(base_chs * exp_ratio)
        same_shape = in_chs == out_chs and stride == 1
        self.has_residual = same_shape and not noskip
        self.drop_path_rate = drop_path_rate

        # Expansion convolution
        self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is None:
            self.se = None
        else:
            se_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **se_kwargs)

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        """Return ``'act1'`` for ``'post_exp'``, otherwise ``'conv_pwl'``."""
        return 'act1' if location == 'post_exp' else 'conv_pwl'

    def feature_channels(self, location):
        """Return the channel count at the given feature location."""
        if location == 'post_exp':
            return self.conv_exp.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        """Run expansion -> attention -> projection (+ residual)."""
        shortcut = x

        # Expansion convolution
        x = self.act1(self.bn1(self.conv_exp(x)))

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.bn2(self.conv_pwl(x))

        if not self.has_residual:
            return x
        if self.drop_path_rate > 0.:
            x = drop_path(x, self.drop_path_rate, self.training)
        x += shortcut
        return x
|