Refactor cspnet configuration using dataclasses, update feature extraction for new cs3 variants.

pull/1327/head
Ross Wightman 2 years ago
parent eca09b8642
commit db0cee9910

@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import collections.abc
from dataclasses import dataclass, field, asdict
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -20,7 +23,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible
from .registry import register_model from .registry import register_model
@ -58,218 +61,278 @@ default_cfgs = {
), ),
'darknetaa53': _cfg(url=''), 'darknetaa53': _cfg(url=''),
'cs3darknet_s': _cfg(
url=''),
'cs3darknet_m': _cfg( 'cs3darknet_m': _cfg(
url=''), url=''),
'cs3darknet_l': _cfg( 'cs3darknet_l': _cfg(
url=''), url=''),
'cs3darknet_x': _cfg(
url=''),
'cs3darknet_focus_s': _cfg(
url=''),
'cs3darknet_focus_m': _cfg( 'cs3darknet_focus_m': _cfg(
url=''), url=''),
'cs3darknet_focus_l': _cfg( 'cs3darknet_focus_l': _cfg(
url=''), url=''),
'cs3darknet_focus_x': _cfg(
url=''),
'cs3sedarknet_xdw': _cfg(
url=''),
} }
@dataclass
class CspStemCfg:
out_chs: Union[int, Tuple[int, ...]] = 32
stride: Union[int, Tuple[int, ...]] = 2
kernel_size: int = 3
padding: Union[int, str] = ''
pool: Optional[str] = ''
def _pad_arg(x, n):
# pads an argument tuple to specified n by padding with last value
if not isinstance(x, (tuple, list)):
x = (x,)
curr_n = len(x)
pad_n = n - curr_n
if pad_n <= 0:
return x[:n]
return tuple(x + (x[-1],) * pad_n)
@dataclass
class CspStagesCfg:
depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages)
out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage
stride: Union[int, Tuple[int, ...]] = 2 # stride of stage
groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups
block_ratio: Union[float, Tuple[float, ...]] = 1.0
bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
avg_down: Union[bool, Tuple[bool, ...]] = False
attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark')
block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark')
# cross-stage only
expand_ratio: Union[float, Tuple[float, ...]] = 1.0
cross_linear: Union[bool, Tuple[bool, ...]] = False
down_growth: Union[bool, Tuple[bool, ...]] = False
def __post_init__(self):
n = len(self.depth)
assert len(self.out_chs) == n
self.stride = _pad_arg(self.stride, n)
self.groups = _pad_arg(self.groups, n)
self.block_ratio = _pad_arg(self.block_ratio, n)
self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
self.avg_down = _pad_arg(self.avg_down, n)
self.attn_layer = _pad_arg(self.attn_layer, n)
self.stage_type = _pad_arg(self.stage_type, n)
self.block_type = _pad_arg(self.block_type, n)
self.expand_ratio = _pad_arg(self.expand_ratio, n)
self.cross_linear = _pad_arg(self.cross_linear, n)
self.down_growth = _pad_arg(self.down_growth, n)
@dataclass
class CspModelCfg:
stem: CspStemCfg
stages: CspStagesCfg
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
aa_layer: Optional[str] = None # FIXME support string factory for this
def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False):
if focus:
stem_cfg = CspStemCfg(
out_chs=make_divisible(64 * width_multiplier),
kernel_size=6, stride=2, padding=2, pool='')
else:
stem_cfg = CspStemCfg(
out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
kernel_size=3, stride=2, pool='')
return CspModelCfg(
stem=stem_cfg,
stages=CspStagesCfg(
out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
stride=2,
bottle_ratio=1.,
block_ratio=0.5,
avg_down=avg_down,
stage_type='cs3',
block_type='dark',
),
act_layer=act_layer,
)
model_cfgs = dict( model_cfgs = dict(
cspresnet50=dict( cspresnet50=CspModelCfg(
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
stage=dict( stages=CspStagesCfg(
out_chs=(128, 256, 512, 1024),
depth=(3, 3, 5, 2), depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3, out_chs=(128, 256, 512, 1024),
exp_ratio=(2.,) * 4, stride=(1, 2),
bottle_ratio=(0.5,) * 4, expand_ratio=2.,
block_ratio=(1.,) * 4, bottle_ratio=0.5,
cross_linear=True, cross_linear=True,
) ),
), ),
cspresnet50d=dict( cspresnet50d=CspModelCfg(
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
stage=dict( stages=CspStagesCfg(
out_chs=(128, 256, 512, 1024),
depth=(3, 3, 5, 2), depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3, out_chs=(128, 256, 512, 1024),
exp_ratio=(2.,) * 4, stride=(1,) + (2,),
bottle_ratio=(0.5,) * 4, expand_ratio=2.,
block_ratio=(1.,) * 4, bottle_ratio=0.5,
block_ratio=1.,
cross_linear=True, cross_linear=True,
) )
), ),
cspresnet50w=dict( cspresnet50w=CspModelCfg(
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
stage=dict( stages=CspStagesCfg(
out_chs=(256, 512, 1024, 2048),
depth=(3, 3, 5, 2), depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3, out_chs=(256, 512, 1024, 2048),
exp_ratio=(1.,) * 4, stride=(1,) + (2,),
bottle_ratio=(0.25,) * 4, expand_ratio=1.,
block_ratio=(0.5,) * 4, bottle_ratio=0.25,
block_ratio=0.5,
cross_linear=True, cross_linear=True,
) )
), ),
cspresnext50=dict( cspresnext50=CspModelCfg(
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
stage=dict( stages=CspStagesCfg(
out_chs=(256, 512, 1024, 2048),
depth=(3, 3, 5, 2), depth=(3, 3, 5, 2),
stride=(1,) + (2,) * 3, out_chs=(256, 512, 1024, 2048),
groups=(32,) * 4, stride=(1,) + (2,),
exp_ratio=(1.,) * 4, groups=32,
bottle_ratio=(1.,) * 4, expand_ratio=1.,
block_ratio=(0.5,) * 4, bottle_ratio=1.,
block_ratio=0.5,
cross_linear=True, cross_linear=True,
) )
), ),
cspdarknet53=dict( cspdarknet53=CspModelCfg(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict( stages=CspStagesCfg(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 2, 8, 8, 4), depth=(1, 2, 8, 8, 4),
stride=(2,) * 5, out_chs=(64, 128, 256, 512, 1024),
exp_ratio=(2.,) + (1.,) * 4, stride=2,
bottle_ratio=(0.5,) + (1.0,) * 4, expand_ratio=(2.,) + (1.,),
block_ratio=(1.,) + (0.5,) * 4, bottle_ratio=(0.5,) + (1.,),
block_ratio=(1.,) + (0.5,),
down_growth=True, down_growth=True,
) block_type='dark',
),
act_layer='leaky_relu',
), ),
darknet17=dict( darknet17=CspModelCfg(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict( stages=CspStagesCfg(
out_chs=(64, 128, 256, 512, 1024),
depth=(1,) * 5, depth=(1,) * 5,
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
)
),
darknet21=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024), out_chs=(64, 128, 256, 512, 1024),
depth=(1, 1, 1, 2, 2), stride=(2,),
stride=(2,) * 5, bottle_ratio=(0.5,),
bottle_ratio=(0.5,) * 5, block_ratio=(1.,),
block_ratio=(1.,) * 5, stage_type='dark',
) block_type='dark',
),
act_layer='leaky_relu',
), ),
sedarknet21=dict( darknet21=CspModelCfg(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict( stages=CspStagesCfg(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 1, 1, 2, 2), depth=(1, 1, 1, 2, 2),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
attn_layer=('se',) * 5,
)
),
darknet53=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024), out_chs=(64, 128, 256, 512, 1024),
depth=(1, 2, 8, 8, 4), stride=(2,),
stride=(2,) * 5, bottle_ratio=(0.5,),
bottle_ratio=(0.5,) * 5, block_ratio=(1.,),
block_ratio=(1.,) * 5, stage_type='dark',
) block_type='dark',
),
darknetaa53=dict(
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict(
out_chs=(64, 128, 256, 512, 1024),
depth=(1, 2, 8, 8, 4),
stride=(2,) * 5,
bottle_ratio=(0.5,) * 5,
block_ratio=(1.,) * 5,
avg_down=True,
), ),
act_layer='leaky_relu',
), ),
sedarknet21=CspModelCfg(
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stages=CspStagesCfg(
depth=(1, 1, 1, 2, 2),
out_chs=(64, 128, 256, 512, 1024),
stride=2,
bottle_ratio=0.5,
block_ratio=1.,
attn_layer='se',
stage_type='dark',
block_type='dark',
cs3darknet_m=dict(
stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''),
stage=dict(
out_chs=(96, 192, 384, 768),
depth=(2, 4, 6, 2),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
), ),
act_layer='leaky_relu',
), ),
cs3darknet_l=dict( darknet53=CspModelCfg(
stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stage=dict( stages=CspStagesCfg(
out_chs=(128, 256, 512, 1024), depth=(1, 2, 8, 8, 4),
depth=(3, 6, 9, 3), out_chs=(64, 128, 256, 512, 1024),
stride=(2,) * 4, stride=2,
bottle_ratio=(1.,) * 4, bottle_ratio=0.5,
block_ratio=(0.5,) * 4, block_ratio=1.,
avg_down=False, stage_type='dark',
block_type='dark',
), ),
act_layer='leaky_relu',
), ),
darknetaa53=CspModelCfg(
cs3darknet_focus_m=dict( stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), stages=CspStagesCfg(
stage=dict( depth=(1, 2, 8, 8, 4),
out_chs=(96, 192, 384, 768), out_chs=(64, 128, 256, 512, 1024),
depth=(2, 4, 6, 2), stride=2,
stride=(2,) * 4, bottle_ratio=0.5,
bottle_ratio=(1.,) * 4, block_ratio=1.,
block_ratio=(0.5,) * 4, avg_down=True,
avg_down=False, stage_type='dark',
block_type='dark',
), ),
act_layer='leaky_relu',
), ),
cs3darknet_focus_l=dict(
stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''),
stage=dict(
out_chs=(128, 256, 512, 1024),
depth=(3, 6, 9, 3),
stride=(2,) * 4,
bottle_ratio=(1.,) * 4,
block_ratio=(0.5,) * 4,
avg_down=False,
),
)
)
cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5),
cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67),
cs3darknet_l=_cs3darknet_cfg(),
cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33),
def create_stem( cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
in_chans=3, cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
out_chs=32, cs3darknet_focus_l=_cs3darknet_cfg(focus=True),
kernel_size=3, cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
stride=2,
pool='', cs3sedarknet_xdw=CspModelCfg(
padding='', stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
act_layer=nn.ReLU, stages=CspStagesCfg(
norm_layer=nn.BatchNorm2d, depth=(3, 6, 12, 4),
aa_layer=None out_chs=(256, 512, 1024, 2048),
): stride=2,
stem = nn.Sequential() groups=(1, 1, 256, 512),
if not isinstance(out_chs, (tuple, list)): bottle_ratio=0.5,
out_chs = [out_chs] block_ratio=0.5,
assert len(out_chs) attn_layer='se',
in_c = in_chans ),
for i, out_c in enumerate(out_chs): ),
conv_name = f'conv{i + 1}' )
stem.add_module(conv_name, ConvNormAct(
in_c, out_c, kernel_size,
stride=stride if i == 0 else 1,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
))
in_c = out_c
last_conv = conv_name
if pool:
if aa_layer is not None:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
stem.add_module('aa', aa_layer(channels=in_c, stride=2))
else:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
class ResBottleneck(nn.Module): class BottleneckBlock(nn.Module):
""" ResNe(X)t Bottleneck Block """ ResNe(X)t Bottleneck Block
""" """
@ -286,9 +349,9 @@ class ResBottleneck(nn.Module):
attn_layer=None, attn_layer=None,
aa_layer=None, aa_layer=None,
drop_block=None, drop_block=None,
drop_path=None drop_path=0.
): ):
super(ResBottleneck, self).__init__() super(BottleneckBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio)) mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@ -299,8 +362,8 @@ class ResBottleneck(nn.Module):
self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
self.drop_path = drop_path self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
self.act3 = act_layer() self.act3 = create_act_layer(act_layer)
def zero_init_last(self): def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight) nn.init.zeros_(self.conv3.bn.weight)
@ -314,9 +377,7 @@ class ResBottleneck(nn.Module):
x = self.conv3(x) x = self.conv3(x)
if self.attn3 is not None: if self.attn3 is not None:
x = self.attn3(x) x = self.attn3(x)
if self.drop_path is not None: x = self.drop_path(x) + shortcut
x = self.drop_path(x)
x = x + shortcut
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl # FIXME partial shortcut needed if first block handled as per original, not used for my current impl
#x[:, :shortcut.size(1)] += shortcut #x[:, :shortcut.size(1)] += shortcut
x = self.act3(x) x = self.act3(x)
@ -339,7 +400,7 @@ class DarkBlock(nn.Module):
attn_layer=None, attn_layer=None,
aa_layer=None, aa_layer=None,
drop_block=None, drop_block=None,
drop_path=None drop_path=0.
): ):
super(DarkBlock, self).__init__() super(DarkBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio)) mid_chs = int(round(out_chs * bottle_ratio))
@ -349,7 +410,7 @@ class DarkBlock(nn.Module):
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer)
self.drop_path = drop_path self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
def zero_init_last(self): def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight) nn.init.zeros_(self.conv2.bn.weight)
@ -360,9 +421,7 @@ class DarkBlock(nn.Module):
x = self.conv2(x) x = self.conv2(x)
if self.attn is not None: if self.attn is not None:
x = self.attn(x) x = self.attn(x)
if self.drop_path is not None: x = self.drop_path(x) + shortcut
x = self.drop_path(x)
x = x + shortcut
return x return x
@ -377,27 +436,27 @@ class CrossStage(nn.Module):
depth, depth,
block_ratio=1., block_ratio=1.,
bottle_ratio=1., bottle_ratio=1.,
exp_ratio=1., expand_ratio=1.,
groups=1, groups=1,
first_dilation=None, first_dilation=None,
avg_down=False, avg_down=False,
down_growth=False, down_growth=False,
cross_linear=False, cross_linear=False,
block_dpr=None, block_dpr=None,
block_fn=ResBottleneck, block_fn=BottleneckBlock,
**block_kwargs **block_kwargs
): ):
super(CrossStage, self).__init__() super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
block_out_chs = int(round(out_chs * block_ratio)) block_out_chs = int(round(out_chs * block_ratio))
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
if stride != 1 or first_dilation != dilation: if stride != 1 or first_dilation != dilation:
if avg_down: if avg_down:
self.conv_down = nn.Sequential( self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
) )
else: else:
@ -417,9 +476,15 @@ class CrossStage(nn.Module):
self.blocks = nn.Sequential() self.blocks = nn.Sequential()
for i in range(depth): for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn( self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs prev_chs = block_out_chs
# transition convs # transition convs
@ -429,7 +494,7 @@ class CrossStage(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv_down(x) x = self.conv_down(x)
x = self.conv_exp(x) x = self.conv_exp(x)
xs, xb = x.split(self.exp_chs // 2, dim=1) xs, xb = x.split(self.expand_chs // 2, dim=1)
xb = self.blocks(xb) xb = self.blocks(xb)
xb = self.conv_transition_b(xb).contiguous() xb = self.conv_transition_b(xb).contiguous()
out = self.conv_transition(torch.cat([xs, xb], dim=1)) out = self.conv_transition(torch.cat([xs, xb], dim=1))
@ -449,27 +514,27 @@ class CrossStage3(nn.Module):
depth, depth,
block_ratio=1., block_ratio=1.,
bottle_ratio=1., bottle_ratio=1.,
exp_ratio=1., expand_ratio=1.,
groups=1, groups=1,
first_dilation=None, first_dilation=None,
avg_down=False, avg_down=False,
down_growth=False, down_growth=False,
cross_linear=False, cross_linear=False,
block_dpr=None, block_dpr=None,
block_fn=ResBottleneck, block_fn=BottleneckBlock,
**block_kwargs **block_kwargs
): ):
super(CrossStage3, self).__init__() super(CrossStage3, self).__init__()
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
block_out_chs = int(round(out_chs * block_ratio)) block_out_chs = int(round(out_chs * block_ratio))
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
if stride != 1 or first_dilation != dilation: if stride != 1 or first_dilation != dilation:
if avg_down: if avg_down:
self.conv_down = nn.Sequential( self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
) )
else: else:
@ -487,9 +552,15 @@ class CrossStage3(nn.Module):
self.blocks = nn.Sequential() self.blocks = nn.Sequential()
for i in range(depth): for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn( self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs prev_chs = block_out_chs
# transition convs # transition convs
@ -498,7 +569,7 @@ class CrossStage3(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv_down(x) x = self.conv_down(x)
x = self.conv_exp(x) x = self.conv_exp(x)
x1, x2 = x.split(self.exp_chs // 2, dim=1) x1, x2 = x.split(self.expand_chs // 2, dim=1)
x1 = self.blocks(x1) x1 = self.blocks(x1)
out = self.conv_transition(torch.cat([x1, x2], dim=1)) out = self.conv_transition(torch.cat([x1, x2], dim=1))
return out return out
@ -519,7 +590,7 @@ class DarkStage(nn.Module):
groups=1, groups=1,
first_dilation=None, first_dilation=None,
avg_down=False, avg_down=False,
block_fn=ResBottleneck, block_fn=BottleneckBlock,
block_dpr=None, block_dpr=None,
**block_kwargs **block_kwargs
): ):
@ -529,7 +600,7 @@ class DarkStage(nn.Module):
if avg_down: if avg_down:
self.conv_down = nn.Sequential( self.conv_down = nn.Sequential(
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
) )
else: else:
@ -541,9 +612,15 @@ class DarkStage(nn.Module):
block_out_chs = int(round(out_chs * block_ratio)) block_out_chs = int(round(out_chs * block_ratio))
self.blocks = nn.Sequential() self.blocks = nn.Sequential()
for i in range(depth): for i in range(depth):
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
self.blocks.add_module(str(i), block_fn( self.blocks.add_module(str(i), block_fn(
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) in_chs=prev_chs,
out_chs=block_out_chs,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
))
prev_chs = block_out_chs prev_chs = block_out_chs
def forward(self, x): def forward(self, x):
@ -552,38 +629,131 @@ class DarkStage(nn.Module):
return x return x
def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): def create_csp_stem(
# get per stage args for stage and containing blocks, calculate strides to meet target output_stride in_chans=3,
num_stages = len(cfg['depth']) out_chs=32,
if 'groups' not in cfg: kernel_size=3,
cfg['groups'] = (1,) * num_stages stride=2,
if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): pool='',
cfg['down_growth'] = (cfg['down_growth'],) * num_stages padding='',
if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): act_layer=nn.ReLU,
cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages norm_layer=nn.BatchNorm2d,
if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): aa_layer=None
cfg['avg_down'] = (cfg['avg_down'],) * num_stages ):
cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ stem = nn.Sequential()
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] feature_info = []
stage_strides = [] if not isinstance(out_chs, (tuple, list)):
stage_dilations = [] out_chs = [out_chs]
stage_first_dilations = [] stem_depth = len(out_chs)
assert stem_depth
assert stride in (1, 2, 4)
prev_feat = None
prev_chs = in_chans
last_idx = stem_depth - 1
stem_stride = 1
for i, chs in enumerate(out_chs):
conv_name = f'conv{i + 1}'
conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
if conv_stride > 1 and prev_feat is not None:
feature_info.append(prev_feat)
stem.add_module(conv_name, ConvNormAct(
prev_chs, chs, kernel_size,
stride=conv_stride,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
))
stem_stride *= conv_stride
prev_chs = chs
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
if pool:
assert stride > 2
if prev_feat is not None:
feature_info.append(prev_feat)
if aa_layer is not None:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
pool_name = 'aa'
else:
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
pool_name = 'pool'
stem_stride *= 2
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
feature_info.append(prev_feat)
return stem, feature_info
def _get_stage_fn(stage_type: str, stage_args):
assert stage_type in ('dark', 'csp', 'cs3')
if stage_type == 'dark':
stage_args.pop('expand_ratio', None)
stage_args.pop('cross_linear', None)
stage_args.pop('down_growth', None)
stage_fn = DarkStage
elif stage_type == 'csp':
stage_fn = CrossStage
else:
stage_fn = CrossStage3
return stage_fn, stage_args
def _get_block_fn(stage_type: str, stage_args):
assert stage_type in ('dark', 'bottle')
if stage_type == 'dark':
return DarkBlock, stage_args
else:
return BottleneckBlock, stage_args
def create_csp_stages(
cfg: CspModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any]
):
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
block_kwargs = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
aa_layer=cfg.aa_layer
)
dilation = 1 dilation = 1
for cfg_stride in cfg['stride']: net_stride = stem_feat['reduction']
stage_first_dilations.append(dilation) prev_chs = stem_feat['num_chs']
if curr_stride >= output_stride: prev_feat = stem_feat
dilation *= cfg_stride feature_info = []
stages = []
for stage_idx, stage_args in enumerate(stage_args):
stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args)
block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args)
stride = stage_args.pop('stride')
if stride != 1 and prev_feat:
feature_info.append(prev_feat)
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1 stride = 1
else: net_stride *= stride
stride = cfg_stride first_dilation = 1 if dilation in (1, 2) else 2
curr_stride *= stride
stage_strides.append(stride) stages += [stage_fn(
stage_dilations.append(dilation) prev_chs,
cfg['stride'] = stage_strides **stage_args,
cfg['dilation'] = stage_dilations stride=stride,
cfg['first_dilation'] = stage_first_dilations first_dilation=first_dilation,
stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] dilation=dilation,
return stage_args block_fn=block_fn,
**block_kwargs,
)]
prev_chs = stage_args['out_chs']
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
feature_info.append(prev_feat)
return nn.Sequential(*stages), feature_info
class CspNet(nn.Module): class CspNet(nn.Module):
@ -598,43 +768,39 @@ class CspNet(nn.Module):
def __init__( def __init__(
self, self,
cfg, cfg: CspModelCfg,
in_chans=3, in_chans=3,
num_classes=1000, num_classes=1000,
output_stride=32, output_stride=32,
global_pool='avg', global_pool='avg',
act_layer=nn.LeakyReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None,
drop_rate=0., drop_rate=0.,
drop_path_rate=0., drop_path_rate=0.,
zero_init_last=True, zero_init_last=True
stage_fn=CrossStage, ):
block_fn=ResBottleneck):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
assert output_stride in (8, 16, 32) assert output_stride in (8, 16, 32)
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) layer_args = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
aa_layer=cfg.aa_layer
)
self.feature_info = []
# Construct the stem # Construct the stem
self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
self.feature_info = [stem_feat_info] self.feature_info.extend(stem_feat_info[:-1])
prev_chs = stem_feat_info['num_chs']
curr_stride = stem_feat_info['reduction'] # reduction does not include pool
if cfg['stem']['pool']:
curr_stride *= 2
# Construct the stages # Construct the stages
per_stage_args = _cfg_to_stage_args( self.stages, stage_feat_info = create_csp_stages(
cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) cfg,
self.stages = nn.Sequential() drop_path_rate=drop_path_rate,
for i, sa in enumerate(per_stage_args): output_stride=output_stride,
self.stages.add_module( stem_feat=stem_feat_info[-1],
str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) )
prev_chs = sa['out_chs'] prev_chs = stage_feat_info[-1]['num_chs']
curr_stride *= sa['stride'] self.feature_info.extend(stage_feat_info)
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
# Construct the head # Construct the head
self.num_features = prev_chs self.num_features = prev_chs
@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs):
@register_model @register_model
def cspdarknet53(pretrained=False, **kwargs): def cspdarknet53(pretrained=False, **kwargs):
return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
@register_model @register_model
def darknet17(pretrained=False, **kwargs): def darknet17(pretrained=False, **kwargs):
return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
@register_model @register_model
def darknet21(pretrained=False, **kwargs): def darknet21(pretrained=False, **kwargs):
return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
@register_model @register_model
def sedarknet21(pretrained=False, **kwargs): def sedarknet21(pretrained=False, **kwargs):
return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
@register_model @register_model
def darknet53(pretrained=False, **kwargs): def darknet53(pretrained=False, **kwargs):
return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
@register_model @register_model
def darknetaa53(pretrained=False, **kwargs): def darknetaa53(pretrained=False, **kwargs):
return _create_cspnet( return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
@register_model
def cs3darknet_s(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
@register_model @register_model
def cs3darknet_m(pretrained=False, **kwargs): def cs3darknet_m(pretrained=False, **kwargs):
return _create_cspnet( return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
@register_model @register_model
def cs3darknet_l(pretrained=False, **kwargs): def cs3darknet_l(pretrained=False, **kwargs):
return _create_cspnet( return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
@register_model
def cs3darknet_x(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_s(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
@register_model @register_model
def cs3darknet_focus_m(pretrained=False, **kwargs): def cs3darknet_focus_m(pretrained=False, **kwargs):
return _create_cspnet( return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
@register_model @register_model
def cs3darknet_focus_l(pretrained=False, **kwargs): def cs3darknet_focus_l(pretrained=False, **kwargs):
return _create_cspnet( return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
@register_model
def cs3darknet_focus_x(pretrained=False, **kwargs):
return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
@register_model
def cs3sedarknet_xdw(pretrained=False, **kwargs):
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save