|
|
@ -18,7 +18,7 @@ import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule
|
|
|
|
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
|
|
|
|
from .registry import register_model
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -195,7 +195,7 @@ class RegStage(nn.Module):
|
|
|
|
"""Stage (sequence of blocks w/ the same output shape)."""
|
|
|
|
"""Stage (sequence of blocks w/ the same output shape)."""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
|
|
|
|
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
|
|
|
|
block_fn=Bottleneck, se_ratio=0.):
|
|
|
|
block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
|
|
|
|
super(RegStage, self).__init__()
|
|
|
|
super(RegStage, self).__init__()
|
|
|
|
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
|
|
|
|
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
|
|
|
|
first_dilation = 1 if dilation in (1, 2) else 2
|
|
|
|
first_dilation = 1 if dilation in (1, 2) else 2
|
|
|
@ -203,6 +203,7 @@ class RegStage(nn.Module):
|
|
|
|
block_stride = stride if i == 0 else 1
|
|
|
|
block_stride = stride if i == 0 else 1
|
|
|
|
block_in_chs = in_chs if i == 0 else out_chs
|
|
|
|
block_in_chs = in_chs if i == 0 else out_chs
|
|
|
|
block_dilation = first_dilation if i == 0 else dilation
|
|
|
|
block_dilation = first_dilation if i == 0 else dilation
|
|
|
|
|
|
|
|
drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
|
|
|
|
if (block_in_chs != out_chs) or (block_stride != 1):
|
|
|
|
if (block_in_chs != out_chs) or (block_stride != 1):
|
|
|
|
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
|
|
|
|
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -212,7 +213,7 @@ class RegStage(nn.Module):
|
|
|
|
self.add_module(
|
|
|
|
self.add_module(
|
|
|
|
name, block_fn(
|
|
|
|
name, block_fn(
|
|
|
|
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
|
|
|
|
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
|
|
|
|
downsample=proj_block, **block_kwargs)
|
|
|
|
downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
@ -229,7 +230,7 @@ class RegNet(nn.Module):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
|
|
|
|
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
|
|
|
|
zero_init_last_bn=True):
|
|
|
|
drop_path_rate=0., zero_init_last_bn=True):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
|
|
|
|
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
@ -244,7 +245,7 @@ class RegNet(nn.Module):
|
|
|
|
# Construct the stages
|
|
|
|
# Construct the stages
|
|
|
|
prev_width = stem_width
|
|
|
|
prev_width = stem_width
|
|
|
|
curr_stride = 2
|
|
|
|
curr_stride = 2
|
|
|
|
stage_params = self._get_stage_params(cfg, output_stride=output_stride)
|
|
|
|
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
|
|
|
|
se_ratio = cfg['se_ratio']
|
|
|
|
se_ratio = cfg['se_ratio']
|
|
|
|
for i, stage_args in enumerate(stage_params):
|
|
|
|
for i, stage_args in enumerate(stage_params):
|
|
|
|
stage_name = "s{}".format(i + 1)
|
|
|
|
stage_name = "s{}".format(i + 1)
|
|
|
@ -272,7 +273,7 @@ class RegNet(nn.Module):
|
|
|
|
if hasattr(m, 'zero_init_last_bn'):
|
|
|
|
if hasattr(m, 'zero_init_last_bn'):
|
|
|
|
m.zero_init_last_bn()
|
|
|
|
m.zero_init_last_bn()
|
|
|
|
|
|
|
|
|
|
|
|
def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
|
|
|
|
def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
|
|
|
|
# Generate RegNet ws per block
|
|
|
|
# Generate RegNet ws per block
|
|
|
|
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
|
|
|
|
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
|
|
|
|
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
|
|
|
|
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
|
|
|
@ -285,24 +286,26 @@ class RegNet(nn.Module):
|
|
|
|
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
|
|
|
|
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
|
|
|
|
stage_strides = []
|
|
|
|
stage_strides = []
|
|
|
|
stage_dilations = []
|
|
|
|
stage_dilations = []
|
|
|
|
total_stride = 2
|
|
|
|
net_stride = 2
|
|
|
|
dilation = 1
|
|
|
|
dilation = 1
|
|
|
|
for _ in range(num_stages):
|
|
|
|
for _ in range(num_stages):
|
|
|
|
if total_stride >= output_stride:
|
|
|
|
if net_stride >= output_stride:
|
|
|
|
dilation *= default_stride
|
|
|
|
dilation *= default_stride
|
|
|
|
stride = 1
|
|
|
|
stride = 1
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
stride = default_stride
|
|
|
|
stride = default_stride
|
|
|
|
total_stride *= stride
|
|
|
|
net_stride *= stride
|
|
|
|
stage_strides.append(stride)
|
|
|
|
stage_strides.append(stride)
|
|
|
|
stage_dilations.append(dilation)
|
|
|
|
stage_dilations.append(dilation)
|
|
|
|
|
|
|
|
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
|
|
|
|
|
|
|
|
|
|
|
|
# Adjust the compatibility of ws and gws
|
|
|
|
# Adjust the compatibility of ws and gws
|
|
|
|
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
|
|
|
|
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
|
|
|
|
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width']
|
|
|
|
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
|
|
|
|
stage_params = [
|
|
|
|
stage_params = [
|
|
|
|
dict(zip(param_names, params)) for params in
|
|
|
|
dict(zip(param_names, params)) for params in
|
|
|
|
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)]
|
|
|
|
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
|
|
|
|
|
|
|
|
stage_dpr)]
|
|
|
|
return stage_params
|
|
|
|
return stage_params
|
|
|
|
|
|
|
|
|
|
|
|
def get_classifier(self):
|
|
|
|
def get_classifier(self):
|
|
|
|