@@ -15,7 +15,7 @@ from math import ceil
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, create_act_layer, ConvBnAct
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
 from .registry import register_model
 
 
@@ -56,10 +56,10 @@ def make_divisible(v, divisor=8, min_value=None):
 
 class SEWithNorm(nn.Module):
 
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
+    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
+        reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
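
Note on the new default (not part of the diff, values are illustrative): the squeeze width is now derived from the channel count via `se_ratio` rather than a fixed `reduction` divisor, so for a 96-channel stage the two defaults no longer agree. A minimal sketch of the arithmetic, assuming `divisor=1`:

    # hypothetical check of the new reduction math
    channels, se_ratio = 96, 1 / 12.
    new_reduction_channels = int(channels * se_ratio)  # 96 * 1/12 -> 8
    old_reduction_channels = channels // 16            # 96 // 16  -> 6
    # the old and new defaults only coincide when reduction == 1 / se_ratio
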
@@ -76,7 +76,7 @@ class SEWithNorm(nn.Module):
 
 
 class LinearBottleneck(nn.Module):
-    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1):
+    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
         super(LinearBottleneck, self).__init__()
         self.use_shortcut = stride == 1 and in_chs <= out_chs
         self.in_channels = in_chs
@@ -90,10 +90,11 @@ class LinearBottleneck(nn.Module):
             self.conv_exp = None
 
         self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
-        self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None
+        self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
         self.act_dw = nn.ReLU6()
 
         self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
+        self.drop_path = drop_path
 
     def feat_channels(self, exp=False):
         return self.conv_dw.out_channels if exp else self.out_channels
@@ -107,12 +108,14 @@ class LinearBottleneck(nn.Module):
             x = self.se(x)
         x = self.act_dw(x)
         x = self.conv_pwl(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.use_shortcut:
             x[:, 0:self.in_channels] += shortcut
         return x
 
 
-def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1):
+def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
     layers = [1, 2, 2, 3, 3, 5]
     strides = [1, 2, 2, 2, 1, 2]
     layers = [ceil(element * depth_mult) for element in layers]
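
Aside on the shortcut the new `drop_path` sits in front of: since `out_chs >= in_chs` whenever `use_shortcut` is set, the residual only covers the leading `in_channels` channels of the output. A standalone toy sketch of that slicing (shapes are assumptions, standing in for a block with `in_chs=16`, `out_chs=24`):

    import torch

    shortcut = torch.randn(1, 16, 8, 8)
    x = torch.randn(1, 24, 8, 8)
    x[:, 0:16] += shortcut  # residual added only on the first in_channels channels
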
@@ -127,29 +130,31 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, us
         out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
         base_chs += final_chs / (depth // 3 * 1.0)
 
-    if use_se:
-        use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
-    else:
-        use_ses = [False] * sum(layers[:])
+    se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:])
 
-    return zip(out_chs_list, exp_ratios, strides, use_ses)
+    return list(zip(out_chs_list, exp_ratios, strides, se_ratios))
 
 
-def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'):
+def _build_blocks(
+        block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
     feat_exp = feature_location == 'expansion'
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2
     features = []
-    for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg):
+    num_blocks = len(block_cfg)
+    for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
             if block_idx > 0 and feat_exp:
                 fname += '.act_dw'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
             curr_stride *= stride
+        block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
+        drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
-            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div))
+            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
+            ch_div=ch_div, drop_path=drop_path))
         prev_chs = chs
         feat_chs += [features[-1].feat_channels(feat_exp)]
 
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
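
For intuition on the linear decay rule added above: each block's drop probability scales with its index, so early blocks are rarely dropped and the last block drops at the full rate. A quick sketch with illustrative numbers (16 blocks, `drop_path_rate=0.1`, neither from the diff):

    drop_path_rate, num_blocks = 0.1, 16
    dprs = [drop_path_rate * i / (num_blocks - 1) for i in range(num_blocks)]
    # [0.0, 0.0067, ..., 0.1] -- block 0 never drops; DropPath is only
    # instantiated where the per-block rate is > 0, matching the diff
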
@@ -162,8 +167,8 @@ def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_l
 
 class ReXNetV1(nn.Module):
     def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
-                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
-                 se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
+                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
+                 ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
         super(ReXNetV1, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
@@ -173,9 +178,9 @@ class ReXNetV1(nn.Module):
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
         self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')
 
-        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div)
+        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location)
+            block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)
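
A quick smoke test of the new constructor arguments, as a sketch only (the input shape and rate values are assumptions, and this bypasses the registry's pretrained-weight handling):

    import torch

    model = ReXNetV1(width_mult=1.0, se_ratio=1 / 12., drop_path_rate=0.05)
    out = model(torch.randn(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 1000])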