@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
-from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
+from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
 from .layers.activations import sigmoid
 
 __all__ = [
@@ -19,31 +19,32 @@ class SqueezeExcite(nn.Module):
 
     Args:
         in_chs (int): input channels to layer
-        se_ratio (float): ratio of squeeze reduction
+        rd_ratio (float): ratio of squeeze reduction
         act_layer (nn.Module): activation layer of containing block
-        gate_fn (Callable): attention gate function
+        gate_layer (Callable): attention gate function
        force_act_layer (nn.Module): override block's activation fn if this is set/bound
-        round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs
+        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
     """
 
     def __init__(
-            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
-            force_act_layer=None, round_chs_fn=None):
+            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
+            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
         super(SqueezeExcite, self).__init__()
-        round_chs_fn = round_chs_fn or round
-        reduced_chs = round_chs_fn(in_chs * se_ratio)
+        if rd_channels is None:
+            rd_round_fn = rd_round_fn or round
+            rd_channels = rd_round_fn(in_chs * rd_ratio)
         act_layer = force_act_layer or act_layer
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
         self.act1 = create_act_layer(act_layer, inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
+        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)
 
 
 class ConvBnAct(nn.Module):
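
For reference, a minimal sketch exercising the refactored SqueezeExcite constructor above. The import path is assumed from the timm package layout this diff targets and may differ in other versions; shapes are illustrative only.

import torch
from timm.models.efficientnet_blocks import SqueezeExcite  # assumed path

x = torch.randn(2, 64, 14, 14)

se = SqueezeExcite(64, rd_ratio=0.25)        # reduced width derived from the ratio: round(64 * 0.25) = 16
se_fixed = SqueezeExcite(64, rd_channels=8)  # reduced width pinned explicitly; rd_ratio is ignored

assert se(x).shape == x.shape
assert se_fixed(x).shape == x.shape
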
@@ -85,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
     """
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate
@@ -99,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module):
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs)
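
With se_ratio dropped from the block signatures, the SE reduction config now has to arrive pre-bound on se_layer itself. A hedged sketch of that pattern follows (functools.partial is one straightforward way to do the binding; the actual builder-side wiring may differ, and the import paths are assumed as above):

from functools import partial

import torch
import torch.nn as nn
from timm.models.efficientnet_blocks import DepthwiseSeparableConv, SqueezeExcite  # assumed paths

# Bind the SE hyper-parameters once; the block then only calls se_layer(chs, act_layer=...).
se_layer = partial(SqueezeExcite, rd_ratio=0.25, gate_layer=nn.Sigmoid)
block = DepthwiseSeparableConv(32, 32, dw_kernel_size=3, se_layer=se_layer)

x = torch.randn(1, 32, 56, 56)
assert block(x).shape == (1, 32, 56, 56)
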
@@ -144,12 +144,11 @@ class InvertedResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -166,7 +165,7 @@ class InvertedResidual(nn.Module):
         self.act2 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -212,8 +211,8 @@ class CondConvResidual(InvertedResidual):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
@@ -221,8 +220,8 @@ class CondConvResidual(InvertedResidual):
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
-            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
+            drop_path_rate=drop_path_rate)
 
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
@@ -271,8 +270,8 @@ class EdgeResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
-            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
         if force_in_chs > 0:
             mid_chs = make_divisible(force_in_chs * exp_ratio)
@@ -289,7 +288,7 @@ class EdgeResidual(nn.Module):
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
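
Note that the EdgeResidual hunk above also stops hard-coding SqueezeExcite and honors a caller-supplied se_layer, matching the other blocks. A small sketch under the same assumed import paths as the earlier examples:

from functools import partial

import torch
from timm.models.efficientnet_blocks import EdgeResidual, SqueezeExcite  # assumed paths

block = EdgeResidual(
    24, 24, exp_kernel_size=3, exp_ratio=4.0,
    se_layer=partial(SqueezeExcite, rd_ratio=0.1))  # SE built on the expanded (mid) channels

x = torch.randn(1, 24, 32, 32)
assert block(x).shape == (1, 24, 32, 32)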