Add DropPath (stochastic depth) to RegNet

pull/227/head
Ross Wightman 4 years ago
parent 47794d2c59
commit 6890300877

@ -18,7 +18,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
from .registry import register_model from .registry import register_model
@ -195,7 +195,7 @@ class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape).""" """Stage (sequence of blocks w/ the same output shape)."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
block_fn=Bottleneck, se_ratio=0.): block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
super(RegStage, self).__init__() super(RegStage, self).__init__()
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
first_dilation = 1 if dilation in (1, 2) else 2 first_dilation = 1 if dilation in (1, 2) else 2
@ -203,6 +203,7 @@ class RegStage(nn.Module):
block_stride = stride if i == 0 else 1 block_stride = stride if i == 0 else 1
block_in_chs = in_chs if i == 0 else out_chs block_in_chs = in_chs if i == 0 else out_chs
block_dilation = first_dilation if i == 0 else dilation block_dilation = first_dilation if i == 0 else dilation
drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
if (block_in_chs != out_chs) or (block_stride != 1): if (block_in_chs != out_chs) or (block_stride != 1):
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
else: else:
@ -212,7 +213,7 @@ class RegStage(nn.Module):
self.add_module( self.add_module(
name, block_fn( name, block_fn(
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
downsample=proj_block, **block_kwargs) downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
) )
def forward(self, x): def forward(self, x):
@ -229,7 +230,7 @@ class RegNet(nn.Module):
""" """
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
zero_init_last_bn=True): drop_path_rate=0., zero_init_last_bn=True):
super().__init__() super().__init__()
# TODO add drop block, drop path, anti-aliasing, custom bn/act args # TODO add drop block, drop path, anti-aliasing, custom bn/act args
self.num_classes = num_classes self.num_classes = num_classes
@ -244,7 +245,7 @@ class RegNet(nn.Module):
# Construct the stages # Construct the stages
prev_width = stem_width prev_width = stem_width
curr_stride = 2 curr_stride = 2
stage_params = self._get_stage_params(cfg, output_stride=output_stride) stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
se_ratio = cfg['se_ratio'] se_ratio = cfg['se_ratio']
for i, stage_args in enumerate(stage_params): for i, stage_args in enumerate(stage_params):
stage_name = "s{}".format(i + 1) stage_name = "s{}".format(i + 1)
@ -272,7 +273,7 @@ class RegNet(nn.Module):
if hasattr(m, 'zero_init_last_bn'): if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn() m.zero_init_last_bn()
def _get_stage_params(self, cfg, default_stride=2, output_stride=32): def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
# Generate RegNet ws per block # Generate RegNet ws per block
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
@ -285,24 +286,26 @@ class RegNet(nn.Module):
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
stage_strides = [] stage_strides = []
stage_dilations = [] stage_dilations = []
total_stride = 2 net_stride = 2
dilation = 1 dilation = 1
for _ in range(num_stages): for _ in range(num_stages):
if total_stride >= output_stride: if net_stride >= output_stride:
dilation *= default_stride dilation *= default_stride
stride = 1 stride = 1
else: else:
stride = default_stride stride = default_stride
total_stride *= stride net_stride *= stride
stage_strides.append(stride) stage_strides.append(stride)
stage_dilations.append(dilation) stage_dilations.append(dilation)
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
# Adjust the compatibility of ws and gws # Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width'] param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
stage_params = [ stage_params = [
dict(zip(param_names, params)) for params in dict(zip(param_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)] zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
stage_dpr)]
return stage_params return stage_params
def get_classifier(self): def get_classifier(self):

Loading…
Cancel
Save