Refactor cspnet configuration using dataclasses, update feature extraction for new cs3 variants.

pull/1327/head
Ross Wightman 2 years ago
parent eca09b8642
commit db0cee9910

@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
from dataclasses import dataclass, field, asdict
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -20,7 +23,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible
from .registry import register_model
@ -58,218 +61,278 @@ default_cfgs = {
),
'darknetaa53': _cfg(url=''),
'cs3darknet_s': _cfg(
url=''),
'cs3darknet_m': _cfg(
url=''),
'cs3darknet_l': _cfg(
url=''),
'cs3darknet_x': _cfg(
url=''),
'cs3darknet_focus_s': _cfg(
url=''),
'cs3darknet_focus_m': _cfg(
url=''),
'cs3darknet_focus_l': _cfg(
url=''),
'cs3darknet_focus_x': _cfg(
url=''),
'cs3sedarknet_xdw': _cfg(
url=''),
}
@dataclass
class CspStemCfg:
out_chs: Union[int, Tuple[int, ...]] = 32
stride: Union[int, Tuple[int, ...]] = 2
kernel_size: int = 3
padding: Union[int, str] = ''
pool: Optional[str] = ''
def _pad_arg(x, n):
# pads an argument tuple to specified n by padding with last value
if not isinstance(x, (tuple, list)):
x = (x,)
curr_n = len(x)
pad_n = n - curr_n
if pad_n <= 0:
return x[:n]
return tuple(x + (x[-1],) * pad_n)
@dataclass
class CspStagesCfg:
depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages)
out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage
stride: Union[int, Tuple[int, ...]] = 2 # stride of stage
groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups
block_ratio: Union[float, Tuple[float, ...]] = 1.0
bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
avg_down: Union[bool, Tuple[bool, ...]] = False
attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark')
block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark')
# cross-stage only
expand_ratio: Union[float, Tuple[float, ...]] = 1.0
cross_linear: Union[bool, Tuple[bool, ...]] = False
down_growth: Union[bool, Tuple[bool, ...]] = False
def __post_init__(self):
n = len(self.depth)
assert len(self.out_chs) == n
self.stride = _pad_arg(self.stride, n)
self.groups = _pad_arg(self.groups, n)
self.block_ratio = _pad_arg(self.block_ratio, n)
self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
self.avg_down = _pad_arg(self.avg_down, n)
self.attn_layer = _pad_arg(self.attn_layer, n)
self.stage_type = _pad_arg(self.stage_type, n)
self.block_type = _pad_arg(self.block_type, n)
self.expand_ratio = _pad_arg(self.expand_ratio, n)
self.cross_linear = _pad_arg(self.cross_linear, n)
self.down_growth = _pad_arg(self.down_growth, n)
@dataclass
class CspModelCfg:
stem: CspStemCfg
stages: CspStagesCfg
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
aa_layer: Optional[str] = None # FIXME support string factory for this
def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False):
if focus:
stem_cfg = CspStemCfg(
out_chs=make_divisible(64 * width_multiplier),
kernel_size=6, stride=2, padding=2, pool='')
else:
stem_cfg = CspStemCfg(
out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
kernel_size=3, stride=2, pool='')
return CspModelCfg(
stem=stem_cfg,
stages=CspStagesCfg(
out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
stride=2,
bottle_ratio=1.,
block_ratio=0.5,
avg_down=avg_down,
stage_type='cs3',
block_type='dark',
),
act_layer=act_layer,
)
model_cfgs = dict(
cspresnet50=dict(
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
stage=dict(
out_chs=(128, 256, 512, 1024),
cspresnet50=CspModelCfg(
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
stages=CspStagesCfg(
depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3,
exp_ratio=(2.,) * 4,
bottle_ratio=(0.5,) * 4,
block_ratio=(1.,) * 4,
out_chs=(128, 256, 512, 1024),
stride=(1, 2),
expand_ratio=2.,
bottle_ratio=0.5,
cross_linear=True,
)
),
),
cspresnet50d=dict(
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
stage=dict(
out_chs=(128, 256, 512, 1024),
cspresnet50d=CspModelCfg(
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
stages=CspStagesCfg(
depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3,
exp_ratio=(2.,) * 4,
bottle_ratio=(0.5,) * 4,
block_ratio=(1.,) * 4,
out_chs=(128, 256, 512, 1024),
stride=(1,) + (2,),
expand_ratio=2.,
bottle_ratio=0.5,
block_ratio=1.,
cross_linear=True,
)
),
cspresnet50w=dict(
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
stage=dict(
out_chs=(256, 512, 1024, 2048),
cspresnet50w=CspModelCfg(
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
stages=CspStagesCfg(
depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3,
exp_ratio=(1.,) * 4,
bottle_ratio=(0.25,) * 4,
block_ratio=(0.5,) * 4,
out_chs=(256, 512, 1024, 2048),
stride=(1,) + (2,),
expand_ratio=1.,
bottle_ratio=0.25,
block_ratio=0.5,
cross_linear=True,
)
),
cspresnext50=dict(
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
stage=dict(
out_chs=(256, 512, 1024, 2048),
cspresnext50=CspModelCfg(
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
stages=CspStagesCfg(
depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3,
groups=(32,) * 4,
exp_ratio=(1.,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
out_chs=(256, 512, 1024, 2048),
stride=(1,) + (2,),
groups=32,
expand_ratio=1.,
bottle_ratio=1.,
block_ratio=0.5,
cross_linear=True,
)
),
cspdarknet53=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
cspdarknet53=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 2, 8, 8, 4),
stride=(2,) * 5,
exp_ratio=(2.,) + (1.,) * 4,
bottle_ratio=(0.5,) + (1.0,) * 4,
block_ratio=(1.,) + (0.5,) * 4,
out_chs=(64, 128, 256, 512, 1024),
stride=2,
expand_ratio=(2.,) + (1.,),
bottle_ratio=(0.5,) + (1.,),
block_ratio=(1.,) + (0.5,),
down_growth=True,
)
block_type='dark',
),
act_layer='leaky_relu',
),
darknet17=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
darknet17=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1,) * 5,
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
)
),
darknet21=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 1, 1, 2, 2),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
)
stride=(2,),
bottle_ratio=(0.5,),
block_ratio=(1.,),
stage_type='dark',
block_type='dark',
),
act_layer='leaky_relu',
),
sedarknet21=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
darknet21=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 1, 1, 2, 2),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
attn_layer=('se',) * 5,
)
),
darknet53=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 2, 8, 8, 4),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
)
),
stride=(2,),
bottle_ratio=(0.5,),
block_ratio=(1.,),
stage_type='dark',
block_type='dark',
darknetaa53=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 2, 8, 8, 4),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
avg_down=True,
),
act_layer='leaky_relu',
),
sedarknet21=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 1, 1, 2, 2),
out_chs=(64, 128, 256, 512, 1024),
stride=2,
bottle_ratio=0.5,
block_ratio=1.,
attn_layer='se',
stage_type='dark',
block_type='dark',
cs3darknet_m=dict(
stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''),
stage=dict(
out_chs=(96, 192, 384, 768),
depth=(2, 4, 6, 2),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
),
act_layer='leaky_relu',
),
cs3darknet_l=dict(
stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
stage=dict(
out_chs=(128, 256, 512, 1024),
depth=(3, 6, 9, 3),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
darknet53=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 2, 8, 8, 4),
out_chs=(64, 128, 256, 512, 1024),
stride=2,
bottle_ratio=0.5,
block_ratio=1.,
stage_type='dark',
block_type='dark',
),
act_layer='leaky_relu',
),
cs3darknet_focus_m=dict(
stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''),
stage=dict(
out_chs=(96, 192, 384, 768),
depth=(2, 4, 6, 2),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
darknetaa53=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 2, 8, 8, 4),
out_chs=(64, 128, 256, 512, 1024),
stride=2,
bottle_ratio=0.5,
block_ratio=1.,
avg_down=True,
stage_type='dark',
block_type='dark',
),
act_layer='leaky_relu',
),
cs3darknet_focus_l=dict(
stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''),
stage=dict(
out_chs=(128, 256, 512, 1024),
depth=(3, 6, 9, 3),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
),
)
)
cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5),
cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67),
cs3darknet_l=_cs3darknet_cfg(),
cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33),
def create_stem(
in_chans=3,
out_chs=32,
kernel_size=3,
stride=2,
pool='',
padding='',
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None
):
stem = nn.Sequential()
if not isinstance(out_chs, (tuple, list)):
out_chs = [out_chs]
assert len(out_chs)
in_c = in_chans
for i, out_c in enumerate(out_chs):
conv_name = f'conv{i + 1}'
stem.add_module(conv_name, ConvNormAct(
in_c, out_c, kernel_size,
stride=stride if i == 0 else 1,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
))
in_c = out_c
last_conv = conv_name
if pool:
if aa_layer is not None:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
stem.add_module('aa', aa_layer(channels=in_c, stride=2))
else:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
cs3darknet_focus_l=_cs3darknet_cfg(focus=True),
cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
cs3sedarknet_xdw=CspModelCfg(
stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
stages=CspStagesCfg(
depth=(3, 6, 12, 4),
out_chs=(256, 512, 1024, 2048),
stride=2,
groups=(1, 1, 256, 512),
bottle_ratio=0.5,
block_ratio=0.5,
attn_layer='se',
),
),
)
class ResBottleneck(nn.Module):
class BottleneckBlock(nn.Module):
""" ResNe(X)t Bottleneck Block
"""
@ -286,9 +349,9 @@ class ResBottleneck(nn.Module):
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None
drop_path=0.
):
super(ResBottleneck, self).__init__()
super(BottleneckBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@ -299,8 +362,8 @@ class ResBottleneck(nn.Module):
self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
self.drop_path = drop_path
self.act3 = act_layer()
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
self.act3 = create_act_layer(act_layer)
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
@ -314,9 +377,7 @@ class ResBottleneck(nn.Module):
x = self.conv3(x)
if self.attn3 is not None:
x = self.attn3(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = x + shortcut
x = self.drop_path(x) + shortcut
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl
#x[:, :shortcut.size(1)] += shortcut
x = self.act3(x)
@ -339,7 +400,7 @@ class DarkBlock(nn.Module):
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None
drop_path=0.
):
super(DarkBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
@ -349,7 +410,7 @@ class DarkBlock(nn.Module):
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer)
self.drop_path = drop_path
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight)
@ -360,9 +421,7 @@ class DarkBlock(nn.Module):
x = self.conv2(x)
if self.attn is not None:
x = self.attn(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = x + shortcut
x = self.drop_path(x) + shortcut
return x
@ -377,27 +436,27 @@ class CrossStage(nn.Module):
depth,
block_ratio=1.,
bottle_ratio=1.,
exp_ratio=1.,
expand_ratio=1.,
groups=1,
first_dilation=None,
avg_down=False,
down_growth=False,
cross_linear=False,
block_dpr=None,
block_fn=ResBottleneck,
block_fn=BottleneckBlock,
**block_kwargs
):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
block_out_chs = int(round(out_chs * block_ratio))
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
if stride != 1 or first_dilation != dilation:
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
@ -417,9 +476,15 @@ class CrossStage(nn.Module):
self.blocks = nn.Sequential()
for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs
# transition convs
@ -429,7 +494,7 @@ class CrossStage(nn.Module):
def forward(self, x):
x = self.conv_down(x)
x = self.conv_exp(x)
xs, xb = x.split(self.exp_chs // 2, dim=1)
xs, xb = x.split(self.expand_chs // 2, dim=1)
xb = self.blocks(xb)
xb = self.conv_transition_b(xb).contiguous()
out = self.conv_transition(torch.cat([xs, xb], dim=1))
@ -449,27 +514,27 @@ class CrossStage3(nn.Module):
depth,
block_ratio=1.,
bottle_ratio=1.,
exp_ratio=1.,
expand_ratio=1.,
groups=1,
first_dilation=None,
avg_down=False,
down_growth=False,
cross_linear=False,
block_dpr=None,
block_fn=ResBottleneck,
block_fn=BottleneckBlock,
**block_kwargs
):
super(CrossStage3, self).__init__()
first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
block_out_chs = int(round(out_chs * block_ratio))
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
if stride != 1 or first_dilation != dilation:
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
@ -487,9 +552,15 @@ class CrossStage3(nn.Module):
self.blocks = nn.Sequential()
for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs
# transition convs
@ -498,7 +569,7 @@ class CrossStage3(nn.Module):
def forward(self, x):
x = self.conv_down(x)
x = self.conv_exp(x)
x1, x2 = x.split(self.exp_chs // 2, dim=1)
x1, x2 = x.split(self.expand_chs // 2, dim=1)
x1 = self.blocks(x1)
out = self.conv_transition(torch.cat([x1, x2], dim=1))
return out
@ -519,7 +590,7 @@ class DarkStage(nn.Module):
groups=1,
first_dilation=None,
avg_down=False,
block_fn=ResBottleneck,
block_fn=BottleneckBlock,
block_dpr=None,
**block_kwargs
):
@ -529,7 +600,7 @@ class DarkStage(nn.Module):
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
@ -541,9 +612,15 @@ class DarkStage(nn.Module):
block_out_chs = int(round(out_chs * block_ratio))
self.blocks = nn.Sequential()
for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs
def forward(self, x):
@ -552,38 +629,131 @@ class DarkStage(nn.Module):
return x
def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
# get per stage args for stage and containing blocks, calculate strides to meet target output_stride
num_stages = len(cfg['depth'])
if 'groups' not in cfg:
cfg['groups'] = (1,) * num_stages
if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)):
cfg['down_growth'] = (cfg['down_growth'],) * num_stages
if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)):
cfg['avg_down'] = (cfg['avg_down'],) * num_stages
cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
stage_strides = []
stage_dilations = []
stage_first_dilations = []
def create_csp_stem(
in_chans=3,
out_chs=32,
kernel_size=3,
stride=2,
pool='',
padding='',
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None
):
stem = nn.Sequential()
feature_info = []
if not isinstance(out_chs, (tuple, list)):
out_chs = [out_chs]
stem_depth = len(out_chs)
assert stem_depth
assert stride in (1, 2, 4)
prev_feat = None
prev_chs = in_chans
last_idx = stem_depth - 1
stem_stride = 1
for i, chs in enumerate(out_chs):
conv_name = f'conv{i + 1}'
conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
if conv_stride > 1 and prev_feat is not None:
feature_info.append(prev_feat)
stem.add_module(conv_name, ConvNormAct(
prev_chs, chs, kernel_size,
stride=conv_stride,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
))
stem_stride *= conv_stride
prev_chs = chs
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
if pool:
assert stride > 2
if prev_feat is not None:
feature_info.append(prev_feat)
if aa_layer is not None:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
pool_name = 'aa'
else:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
pool_name = 'pool'
stem_stride *= 2
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
feature_info.append(prev_feat)
return stem, feature_info
def _get_stage_fn(stage_type: str, stage_args):
assert stage_type in ('dark', 'csp', 'cs3')
if stage_type == 'dark':
stage_args.pop('expand_ratio', None)
stage_args.pop('cross_linear', None)
stage_args.pop('down_growth', None)
stage_fn = DarkStage
elif stage_type == 'csp':
stage_fn = CrossStage
else:
stage_fn = CrossStage3
return stage_fn, stage_args
def _get_block_fn(stage_type: str, stage_args):
assert stage_type in ('dark', 'bottle')
if stage_type == 'dark':
return DarkBlock, stage_args
else:
return BottleneckBlock, stage_args
def create_csp_stages(
cfg: CspModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any]
):
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
block_kwargs = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
aa_layer=cfg.aa_layer
)
dilation = 1
for cfg_stride in cfg['stride']:
stage_first_dilations.append(dilation)
if curr_stride >= output_stride:
dilation *= cfg_stride
net_stride = stem_feat['reduction']
prev_chs = stem_feat['num_chs']
prev_feat = stem_feat
feature_info = []
stages = []
for stage_idx, stage_args in enumerate(stage_args):
stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args)
block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args)
stride = stage_args.pop('stride')
if stride != 1 and prev_feat:
feature_info.append(prev_feat)
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
else:
stride = cfg_stride
curr_stride *= stride
stage_strides.append(stride)
stage_dilations.append(dilation)
cfg['stride'] = stage_strides
cfg['dilation'] = stage_dilations
cfg['first_dilation'] = stage_first_dilations
stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())]
return stage_args
net_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
stages += [stage_fn(
prev_chs,
**stage_args,
stride=stride,
first_dilation=first_dilation,
dilation=dilation,
block_fn=block_fn,
**block_kwargs,
)]
prev_chs = stage_args['out_chs']
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
feature_info.append(prev_feat)
return nn.Sequential(*stages), feature_info
class CspNet(nn.Module):
@ -598,43 +768,39 @@ class CspNet(nn.Module):
def __init__(
self,
cfg,
cfg: CspModelCfg,
in_chans=3,
num_classes=1000,
output_stride=32,
global_pool='avg',
act_layer=nn.LeakyReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
stage_fn=CrossStage,
block_fn=ResBottleneck):
zero_init_last=True
):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
layer_args = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
aa_layer=cfg.aa_layer
)
self.feature_info = []
# Construct the stem
self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)
self.feature_info = [stem_feat_info]
prev_chs = stem_feat_info['num_chs']
curr_stride = stem_feat_info['reduction'] # reduction does not include pool
if cfg['stem']['pool']:
curr_stride *= 2
self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
self.feature_info.extend(stem_feat_info[:-1])
# Construct the stages
per_stage_args = _cfg_to_stage_args(
cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)
self.stages = nn.Sequential()
for i, sa in enumerate(per_stage_args):
self.stages.add_module(
str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
prev_chs = sa['out_chs']
curr_stride *= sa['stride']
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages, stage_feat_info = create_csp_stages(
cfg,
drop_path_rate=drop_path_rate,
output_stride=output_stride,
stem_feat=stem_feat_info[-1],
)
prev_chs = stage_feat_info[-1]['num_chs']
self.feature_info.extend(stage_feat_info)
# Construct the head
self.num_features = prev_chs
@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs):
@register_model
def cspdarknet53(pretrained=False, **kwargs):
return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
@register_model
def darknet17(pretrained=False, **kwargs):
return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
@register_model
def darknet21(pretrained=False, **kwargs):
return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
@register_model
def sedarknet21(pretrained=False, **kwargs):
return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
@register_model
def darknet53(pretrained=False, **kwargs):
return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
@register_model
def darknetaa53(pretrained=False, **kwargs):
return _create_cspnet(
'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_s(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_m(pretrained=False, **kwargs):
return _create_cspnet(
'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_l(pretrained=False, **kwargs):
return _create_cspnet(
'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_x(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_s(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_m(pretrained=False, **kwargs):
return _create_cspnet(
'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_l(pretrained=False, **kwargs):
return _create_cspnet(
'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_x(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
@register_model
def cs3sedarknet_xdw(pretrained=False, **kwargs):
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save