Add DropPath (stochastic depth) to ReXNet and VoVNet. RegNet DropPath impl tweak and dedupe se args.

pull/244/head
Ross Wightman 4 years ago
parent e8ca45854c
commit e8e2d9cabf
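
Note on the technique: DropPath is per-sample stochastic depth. During training the residual branch of a block is zeroed for a random subset of the samples in a batch and the surviving samples are rescaled by 1 / keep_prob so the expected activation is unchanged; at inference it is an identity op. A minimal sketch of such a layer, for illustration only (the models below use the real timm.models.layers.DropPath):

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Per-sample stochastic depth, illustrative sketch (not the timm source)."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):  # assumes NCHW input
        if self.drop_prob == 0. or not self.training:
            return x  # identity at inference or when disabled
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast over C, H, W
        mask = (torch.rand(x.shape[0], 1, 1, 1, device=x.device) < keep_prob).to(x.dtype)
        return x * mask / keep_prob  # rescale so the expected output matches the input

In all three models touched here, the drop is applied to the block output immediately before the residual/shortcut add, so a dropped path leaves (roughly) an identity mapping for that sample.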

@@ -195,7 +195,7 @@ class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""

     def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
-                 block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
+                 block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None):
         super(RegStage, self).__init__()
         block_kwargs = {}  # FIXME setup to pass various aa, norm, act layer common args
         first_dilation = 1 if dilation in (1, 2) else 2
@@ -203,7 +203,10 @@ class RegStage(nn.Module):
             block_stride = stride if i == 0 else 1
             block_in_chs = in_chs if i == 0 else out_chs
             block_dilation = first_dilation if i == 0 else dilation
-            drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
+            if drop_path_rates is not None and drop_path_rates[i] > 0.:
+                drop_path = DropPath(drop_path_rates[i])
+            else:
+                drop_path = None
             if (block_in_chs != out_chs) or (block_stride != 1):
                 proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
             else:
@@ -301,7 +304,7 @@ class RegNet(nn.Module):
         # Adjust the compatibility of ws and gws
         stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
-        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
+        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates']
         stage_params = [
             dict(zip(param_names, params)) for params in
             zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
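
In the RegStage change above, the per-stage argument becomes drop_path_rates (a list of per-block rates) and a DropPath module is only instantiated when a block's rate is strictly positive, so blocks with a zero rate simply keep drop_path=None. The per-stage lists are typically produced by ramping a single model-level rate linearly across all blocks and splitting by stage depth; a sketch of that idea (the helper name and depths are made up for illustration):

import torch

def per_stage_drop_path_rates(drop_path_rate, stage_depths):
    # linear ramp from 0 to drop_path_rate over every block, split into per-stage lists
    rates = torch.linspace(0, drop_path_rate, sum(stage_depths)).split(stage_depths)
    return [r.tolist() for r in rates]

print(per_stage_drop_path_rates(0.1, [1, 2, 4, 7]))
# the very first block gets 0.0 (never dropped); the last block gets the full 0.1

The VovNet hunk at the end of this commit uses exactly this torch.linspace + torch.split pattern.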

@@ -15,7 +15,7 @@ from math import ceil

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, create_act_layer, ConvBnAct
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
 from .registry import register_model
@@ -56,10 +56,10 @@ def make_divisible(v, divisor=8, min_value=None):

 class SEWithNorm(nn.Module):

-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
+    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
+        reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
@@ -76,7 +76,7 @@ class SEWithNorm(nn.Module):

 class LinearBottleneck(nn.Module):
-    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1):
+    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
         super(LinearBottleneck, self).__init__()
         self.use_shortcut = stride == 1 and in_chs <= out_chs
         self.in_channels = in_chs
@@ -90,10 +90,11 @@ class LinearBottleneck(nn.Module):
             self.conv_exp = None

         self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
-        self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None
+        self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
         self.act_dw = nn.ReLU6()
         self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
+        self.drop_path = drop_path

     def feat_channels(self, exp=False):
         return self.conv_dw.out_channels if exp else self.out_channels
@@ -107,12 +108,14 @@ class LinearBottleneck(nn.Module):
             x = self.se(x)
         x = self.act_dw(x)
         x = self.conv_pwl(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.use_shortcut:
             x[:, 0:self.in_channels] += shortcut
         return x


-def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1):
+def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
     layers = [1, 2, 2, 3, 3, 5]
     strides = [1, 2, 2, 2, 1, 2]
     layers = [ceil(element * depth_mult) for element in layers]
@@ -127,29 +130,31 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, us
         out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
         base_chs += final_chs / (depth // 3 * 1.0)

-    if use_se:
-        use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
-    else:
-        use_ses = [False] * sum(layers[:])
+    se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:])

-    return zip(out_chs_list, exp_ratios, strides, use_ses)
+    return list(zip(out_chs_list, exp_ratios, strides, se_ratios))


-def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'):
+def _build_blocks(
+        block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
     feat_exp = feature_location == 'expansion'
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2
     features = []
-    for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg):
+    num_blocks = len(block_cfg)
+    for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
             if block_idx > 0 and feat_exp:
                 fname += '.act_dw'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
             curr_stride *= stride
+        block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
+        drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
-            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div))
+            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
+            ch_div=ch_div, drop_path=drop_path))
         prev_chs = chs
         feat_chs += [features[-1].feat_channels(feat_exp)]
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
@@ -162,8 +167,8 @@ def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_l

 class ReXNetV1(nn.Module):
     def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
-                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
-                 se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
+                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
+                 ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
         super(ReXNetV1, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
@@ -173,9 +178,9 @@ class ReXNetV1(nn.Module):
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
         self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')

-        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div)
+        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location)
+            block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)
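
Two things change in ReXNet: the use_se / se_rd / reduction arguments are deduplicated into a single se_ratio (the 1/12 default plays the role of the old reduction of 12), and every block now receives its drop path rate from the linear decay rule block_dpr = drop_path_rate * block_idx / (num_blocks - 1), so the first block is never dropped and the last block is dropped at the full rate. A quick worked example with illustrative numbers (the real num_blocks depends on depth_mult):

# stochastic depth linear decay rule, illustrative values only
drop_path_rate = 0.1
num_blocks = 16
rates = [drop_path_rate * i / (num_blocks - 1) for i in range(num_blocks)]
print(rates[0], rates[7], rates[-1])  # 0.0, ~0.047, 0.1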

@@ -20,7 +20,7 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .registry import register_model
 from .helpers import build_model_with_cfg
-from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, \
+from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath, \
     create_attn, create_norm_act, get_norm_act_layer
@@ -179,7 +179,7 @@ class SequentialAppendList(nn.Sequential):

 class OsaBlock(nn.Module):

     def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
-                 depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
         super(OsaBlock, self).__init__()
         self.residual = residual
@@ -212,6 +212,8 @@ class OsaBlock(nn.Module):
         else:
             self.attn = None

+        self.drop_path = drop_path
+
     def forward(self, x):
         output = [x]
         if self.conv_reduction is not None:
@@ -220,6+222,8 @@ class OsaBlock(nn.Module):
         x = self.conv_concat(x)
         if self.attn is not None:
             x = self.attn(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.residual:
             x = x + output[0]
         return x
@@ -228,7 +232,8 @@ class OsaBlock(nn.Module):

 class OsaStage(nn.Module):

     def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
-                 residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
+                 drop_path_rates=None):
         super(OsaStage, self).__init__()
         if downsample:
@@ -239,10 +244,15 @@ class OsaStage(nn.Module):
         blocks = []
         for i in range(block_per_stage):
             last_block = i == block_per_stage - 1
+            if drop_path_rates is not None and drop_path_rates[i] > 0.:
+                drop_path = DropPath(drop_path_rates[i])
+            else:
+                drop_path = None
             blocks += [OsaBlock(
-                in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
-                depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer)
+                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
+                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
             ]
+            in_chs = out_chs
         self.blocks = nn.Sequential(*blocks)

     def forward(self, x):
@@ -255,7 +265,7 @@ class OsaStage(nn.Module):

 class VovNet(nn.Module):

     def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
-                 output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
         """ VovNet (v2)
         """
         super(VovNet, self).__init__()
@@ -284,6 +294,7 @@ class VovNet(nn.Module):
         current_stride = stem_stride

         # OSA stages
+        stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
         in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
         stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
         stages = []
@@ -291,7 +302,7 @@ class VovNet(nn.Module):
             downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
             stages += [OsaStage(
                 in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
-                downsample=downsample, **stage_args)
+                downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
             ]
             self.num_features = stage_out_chs[i]
             current_stride *= 2 if downsample else 1
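
VovNet spreads the single drop_path_rate over all OSA blocks with torch.linspace and splits it per stage with torch.split; each OsaStage then turns a positive per-block rate into a DropPath module and leaves drop_path=None otherwise. The OsaBlock channel plumbing is also simplified: rather than selecting in_chs if i == 0 else out_chs for every block, the loop now reassigns in_chs = out_chs after building each block, which is equivalent. With the new drop_path_rate constructor argument, stochastic depth can be requested when building a model; a usage sketch, assuming these model names are registered and that extra create_model kwargs are forwarded to the constructor (as is usual in timm):

import timm

# stochastic depth is only active in train mode; eval() makes DropPath an identity
vovnet = timm.create_model('ese_vovnet39b', drop_path_rate=0.1)
rexnet = timm.create_model('rexnet_100', drop_path_rate=0.1)
vovnet.train()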
