@@ -1,8 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from .layers.activations import sigmoid
-from .layers import create_conv2d, drop_path
+from .layers import create_conv2d, create_attn, drop_path
 
 
 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -30,26 +29,21 @@ def resolve_bn_args(kwargs):
     return bn_args
 
 
-_SE_ARGS_DEFAULT = dict(
-    gate_fn=sigmoid,
-    act_layer=None,
-    reduce_mid=False,
-    divisor=1)
-
-
-def resolve_se_args(kwargs, in_chs, act_layer=None):
-    se_kwargs = kwargs.copy() if kwargs is not None else {}
-    # fill in args that aren't specified with the defaults
-    for k, v in _SE_ARGS_DEFAULT.items():
-        se_kwargs.setdefault(k, v)
-    # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
-    if not se_kwargs.pop('reduce_mid'):
-        se_kwargs['reduced_base_chs'] = in_chs
-    # act_layer override, if it remains None, the containing block's act_layer will be used
-    if se_kwargs['act_layer'] is None:
-        assert act_layer is not None
-        se_kwargs['act_layer'] = act_layer
-    return se_kwargs
+def resolve_attn_args(layer, kwargs, in_chs, act_layer=None):
+    attn_kwargs = kwargs.copy() if kwargs is not None else {}
+    if not isinstance(layer, str):
+        is_se = 'SqueezeExciteV2' in layer.__name__
+    else:
+        is_se = layer == 'sev2'
+    if is_se:
+        # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
+        if not attn_kwargs.pop('reduce_mid', False):
+            attn_kwargs['reduced_base_chs'] = in_chs
+        # if act_layer is not defined by attn kwargs, the containing block's act_layer will be used for attn
+        if attn_kwargs.get('act_layer', None) is None:
+            assert act_layer is not None
+            attn_kwargs['act_layer'] = act_layer
+    return attn_kwargs
 
 
 def make_divisible(v, divisor=8, min_value=None):
@@ -90,26 +84,6 @@ class ChannelShuffle(nn.Module):
         )
 
 
-class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-                 act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
-        super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
-        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-        self.act1 = act_layer(inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-
-    def forward(self, x):
-        x_se = self.avg_pool(x)
-        x_se = self.conv_reduce(x_se)
-        x_se = self.act1(x_se)
-        x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
-
-
 class ConvBnAct(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
@@ -140,11 +114,10 @@ class DepthwiseSeparableConv(nn.Module):
     """
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
+                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         norm_kwargs = norm_kwargs or {}
-        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate
@@ -154,10 +127,10 @@ class DepthwiseSeparableConv(nn.Module):
         self.bn1 = norm_layer(in_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
 
-        # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
+        # Attention block (Squeeze-Excitation, ECA, etc)
+        if attn_layer is not None:
+            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+            self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
         else:
             self.se = None
 
@@ -199,13 +172,12 @@ class InvertedResidual(nn.Module):
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                  conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -221,10 +193,10 @@ class InvertedResidual(nn.Module):
         self.bn2 = norm_layer(mid_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
 
-        # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+        # Attention block (Squeeze-Excitation, ECA, etc)
+        if attn_layer is not None:
+            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
         else:
             self.se = None
 
@@ -256,7 +228,7 @@ class InvertedResidual(nn.Module):
         x = self.bn2(x)
         x = self.act2(x)
 
-        # Squeeze-and-excitation
+        # Attention
         if self.se is not None:
             x = self.se(x)
 
@@ -278,7 +250,7 @@ class CondConvResidual(InvertedResidual):
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                  num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
@@ -287,7 +259,7 @@ class CondConvResidual(InvertedResidual):
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
+            pw_kernel_size=pw_kernel_size, attn_layer=attn_layer, attn_kwargs=attn_kwargs,
             norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
             drop_path_rate=drop_path_rate)
 
@@ -310,7 +282,7 @@ class CondConvResidual(InvertedResidual):
         x = self.bn2(x)
         x = self.act2(x)
 
-        # Squeeze-and-excitation
+        # Attention
         if self.se is not None:
             x = self.se(x)
 
@@ -330,7 +302,7 @@ class EdgeResidual(nn.Module):
 
     def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                  drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -338,7 +310,6 @@ class EdgeResidual(nn.Module):
             mid_chs = make_divisible(fake_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -347,10 +318,10 @@ class EdgeResidual(nn.Module):
         self.bn1 = norm_layer(mid_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
 
-        # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+        # Attention block (Squeeze-Excitation, ECA, etc)
+        if attn_layer is not None:
+            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
         else:
             self.se = None
 
@@ -378,7 +349,7 @@ class EdgeResidual(nn.Module):
         x = self.bn1(x)
         x = self.act1(x)
 
-        # Squeeze-and-excitation
+        # Attention
         if self.se is not None:
             x = self.se(x)
 
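Illustrative usage sketch of the new attention hook, not taken from the patch itself: it assumes the patched blocks live at timm.models.efficientnet_blocks, that 'sev2' is a name registered with create_attn, and that the layer it builds accepts the same se_ratio / reduced_base_chs / act_layer kwargs the removed SqueezeExcite took. SE-specific arguments now travel inside attn_kwargs rather than se_ratio / se_kwargs.

    import torch
    from timm.models.efficientnet_blocks import InvertedResidual  # import path assumed

    # Old (removed) call style:
    #   block = InvertedResidual(32, 64, exp_ratio=6.0, se_ratio=0.25)
    # New call style: the attention type is explicit; resolve_attn_args fills in
    # reduced_base_chs and act_layer before create_attn builds the module.
    block = InvertedResidual(
        in_chs=32, out_chs=64, dw_kernel_size=3, stride=1, exp_ratio=6.0,
        attn_layer='sev2',                 # string name or an attention module class
        attn_kwargs=dict(se_ratio=0.25),   # hypothetical kwargs for the 'sev2' entry
    )

    x = torch.randn(2, 32, 56, 56)
    print(block(x).shape)  # expected: torch.Size([2, 64, 56, 56])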