|
|
|
@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import collections.abc
|
|
|
|
|
from dataclasses import dataclass, field, asdict
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -20,7 +23,7 @@ import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
|
|
|
|
|
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer
|
|
|
|
|
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -58,218 +61,278 @@ default_cfgs = {
|
|
|
|
|
),
|
|
|
|
|
'darknetaa53': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
'cs3darknet_s': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_m': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_l': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_x': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
|
|
|
|
|
'cs3darknet_focus_s': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_focus_m': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_focus_l': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
'cs3darknet_focus_x': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
|
|
|
|
|
'cs3sedarknet_xdw': _cfg(
|
|
|
|
|
url=''),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class CspStemCfg:
|
|
|
|
|
out_chs: Union[int, Tuple[int, ...]] = 32
|
|
|
|
|
stride: Union[int, Tuple[int, ...]] = 2
|
|
|
|
|
kernel_size: int = 3
|
|
|
|
|
padding: Union[int, str] = ''
|
|
|
|
|
pool: Optional[str] = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pad_arg(x, n):
|
|
|
|
|
# pads an argument tuple to specified n by padding with last value
|
|
|
|
|
if not isinstance(x, (tuple, list)):
|
|
|
|
|
x = (x,)
|
|
|
|
|
curr_n = len(x)
|
|
|
|
|
pad_n = n - curr_n
|
|
|
|
|
if pad_n <= 0:
|
|
|
|
|
return x[:n]
|
|
|
|
|
return tuple(x + (x[-1],) * pad_n)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class CspStagesCfg:
|
|
|
|
|
depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages)
|
|
|
|
|
out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage
|
|
|
|
|
stride: Union[int, Tuple[int, ...]] = 2 # stride of stage
|
|
|
|
|
groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups
|
|
|
|
|
block_ratio: Union[float, Tuple[float, ...]] = 1.0
|
|
|
|
|
bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
|
|
|
|
|
avg_down: Union[bool, Tuple[bool, ...]] = False
|
|
|
|
|
attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
|
|
|
|
|
stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark')
|
|
|
|
|
block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark')
|
|
|
|
|
|
|
|
|
|
# cross-stage only
|
|
|
|
|
expand_ratio: Union[float, Tuple[float, ...]] = 1.0
|
|
|
|
|
cross_linear: Union[bool, Tuple[bool, ...]] = False
|
|
|
|
|
down_growth: Union[bool, Tuple[bool, ...]] = False
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
n = len(self.depth)
|
|
|
|
|
assert len(self.out_chs) == n
|
|
|
|
|
self.stride = _pad_arg(self.stride, n)
|
|
|
|
|
self.groups = _pad_arg(self.groups, n)
|
|
|
|
|
self.block_ratio = _pad_arg(self.block_ratio, n)
|
|
|
|
|
self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
|
|
|
|
|
self.avg_down = _pad_arg(self.avg_down, n)
|
|
|
|
|
self.attn_layer = _pad_arg(self.attn_layer, n)
|
|
|
|
|
self.stage_type = _pad_arg(self.stage_type, n)
|
|
|
|
|
self.block_type = _pad_arg(self.block_type, n)
|
|
|
|
|
|
|
|
|
|
self.expand_ratio = _pad_arg(self.expand_ratio, n)
|
|
|
|
|
self.cross_linear = _pad_arg(self.cross_linear, n)
|
|
|
|
|
self.down_growth = _pad_arg(self.down_growth, n)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class CspModelCfg:
|
|
|
|
|
stem: CspStemCfg
|
|
|
|
|
stages: CspStagesCfg
|
|
|
|
|
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
|
|
|
|
|
act_layer: str = 'relu'
|
|
|
|
|
norm_layer: str = 'batchnorm'
|
|
|
|
|
aa_layer: Optional[str] = None # FIXME support string factory for this
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False):
|
|
|
|
|
if focus:
|
|
|
|
|
stem_cfg = CspStemCfg(
|
|
|
|
|
out_chs=make_divisible(64 * width_multiplier),
|
|
|
|
|
kernel_size=6, stride=2, padding=2, pool='')
|
|
|
|
|
else:
|
|
|
|
|
stem_cfg = CspStemCfg(
|
|
|
|
|
out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
|
|
|
|
|
kernel_size=3, stride=2, pool='')
|
|
|
|
|
return CspModelCfg(
|
|
|
|
|
stem=stem_cfg,
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
|
|
|
|
|
depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
|
|
|
|
|
stride=2,
|
|
|
|
|
bottle_ratio=1.,
|
|
|
|
|
block_ratio=0.5,
|
|
|
|
|
avg_down=avg_down,
|
|
|
|
|
stage_type='cs3',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
),
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cfgs = dict(
|
|
|
|
|
cspresnet50=dict(
|
|
|
|
|
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
cspresnet50=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(3, 3, 5, 2),
|
|
|
|
|
stride=(1,) + (2,) * 3,
|
|
|
|
|
exp_ratio=(2.,) * 4,
|
|
|
|
|
bottle_ratio=(0.5,) * 4,
|
|
|
|
|
block_ratio=(1.,) * 4,
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
stride=(1, 2),
|
|
|
|
|
expand_ratio=2.,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
cross_linear=True,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
cspresnet50d=dict(
|
|
|
|
|
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
cspresnet50d=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(3, 3, 5, 2),
|
|
|
|
|
stride=(1,) + (2,) * 3,
|
|
|
|
|
exp_ratio=(2.,) * 4,
|
|
|
|
|
bottle_ratio=(0.5,) * 4,
|
|
|
|
|
block_ratio=(1.,) * 4,
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
stride=(1,) + (2,),
|
|
|
|
|
expand_ratio=2.,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
cross_linear=True,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
cspresnet50w=dict(
|
|
|
|
|
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(256, 512, 1024, 2048),
|
|
|
|
|
cspresnet50w=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(3, 3, 5, 2),
|
|
|
|
|
stride=(1,) + (2,) * 3,
|
|
|
|
|
exp_ratio=(1.,) * 4,
|
|
|
|
|
bottle_ratio=(0.25,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
out_chs=(256, 512, 1024, 2048),
|
|
|
|
|
stride=(1,) + (2,),
|
|
|
|
|
expand_ratio=1.,
|
|
|
|
|
bottle_ratio=0.25,
|
|
|
|
|
block_ratio=0.5,
|
|
|
|
|
cross_linear=True,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
cspresnext50=dict(
|
|
|
|
|
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(256, 512, 1024, 2048),
|
|
|
|
|
cspresnext50=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(3, 3, 5, 2),
|
|
|
|
|
stride=(1,) + (2,) * 3,
|
|
|
|
|
groups=(32,) * 4,
|
|
|
|
|
exp_ratio=(1.,) * 4,
|
|
|
|
|
bottle_ratio=(1.,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
out_chs=(256, 512, 1024, 2048),
|
|
|
|
|
stride=(1,) + (2,),
|
|
|
|
|
groups=32,
|
|
|
|
|
expand_ratio=1.,
|
|
|
|
|
bottle_ratio=1.,
|
|
|
|
|
block_ratio=0.5,
|
|
|
|
|
cross_linear=True,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
cspdarknet53=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
cspdarknet53=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1, 2, 8, 8, 4),
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
exp_ratio=(2.,) + (1.,) * 4,
|
|
|
|
|
bottle_ratio=(0.5,) + (1.0,) * 4,
|
|
|
|
|
block_ratio=(1.,) + (0.5,) * 4,
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
stride=2,
|
|
|
|
|
expand_ratio=(2.,) + (1.,),
|
|
|
|
|
bottle_ratio=(0.5,) + (1.,),
|
|
|
|
|
block_ratio=(1.,) + (0.5,),
|
|
|
|
|
down_growth=True,
|
|
|
|
|
)
|
|
|
|
|
block_type='dark',
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
darknet17=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
darknet17=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1,) * 5,
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
bottle_ratio=(0.5,) * 5,
|
|
|
|
|
block_ratio=(1.,) * 5,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
darknet21=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
depth=(1, 1, 1, 2, 2),
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
bottle_ratio=(0.5,) * 5,
|
|
|
|
|
block_ratio=(1.,) * 5,
|
|
|
|
|
)
|
|
|
|
|
stride=(2,),
|
|
|
|
|
bottle_ratio=(0.5,),
|
|
|
|
|
block_ratio=(1.,),
|
|
|
|
|
stage_type='dark',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
sedarknet21=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
darknet21=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1, 1, 1, 2, 2),
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
bottle_ratio=(0.5,) * 5,
|
|
|
|
|
block_ratio=(1.,) * 5,
|
|
|
|
|
attn_layer=('se',) * 5,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
darknet53=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
depth=(1, 2, 8, 8, 4),
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
bottle_ratio=(0.5,) * 5,
|
|
|
|
|
block_ratio=(1.,) * 5,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
stride=(2,),
|
|
|
|
|
bottle_ratio=(0.5,),
|
|
|
|
|
block_ratio=(1.,),
|
|
|
|
|
stage_type='dark',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
|
|
|
|
|
darknetaa53=dict(
|
|
|
|
|
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
depth=(1, 2, 8, 8, 4),
|
|
|
|
|
stride=(2,) * 5,
|
|
|
|
|
bottle_ratio=(0.5,) * 5,
|
|
|
|
|
block_ratio=(1.,) * 5,
|
|
|
|
|
avg_down=True,
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
sedarknet21=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1, 1, 1, 2, 2),
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
stride=2,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
attn_layer='se',
|
|
|
|
|
stage_type='dark',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
|
|
|
|
|
cs3darknet_m=dict(
|
|
|
|
|
stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(96, 192, 384, 768),
|
|
|
|
|
depth=(2, 4, 6, 2),
|
|
|
|
|
stride=(2,) * 4,
|
|
|
|
|
bottle_ratio=(1.,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
cs3darknet_l=dict(
|
|
|
|
|
stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
depth=(3, 6, 9, 3),
|
|
|
|
|
stride=(2,) * 4,
|
|
|
|
|
bottle_ratio=(1.,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
darknet53=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1, 2, 8, 8, 4),
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
stride=2,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
stage_type='dark',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
cs3darknet_focus_m=dict(
|
|
|
|
|
stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(96, 192, 384, 768),
|
|
|
|
|
depth=(2, 4, 6, 2),
|
|
|
|
|
stride=(2,) * 4,
|
|
|
|
|
bottle_ratio=(1.,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
darknetaa53=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(1, 2, 8, 8, 4),
|
|
|
|
|
out_chs=(64, 128, 256, 512, 1024),
|
|
|
|
|
stride=2,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
avg_down=True,
|
|
|
|
|
stage_type='dark',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
),
|
|
|
|
|
act_layer='leaky_relu',
|
|
|
|
|
),
|
|
|
|
|
cs3darknet_focus_l=dict(
|
|
|
|
|
stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''),
|
|
|
|
|
stage=dict(
|
|
|
|
|
out_chs=(128, 256, 512, 1024),
|
|
|
|
|
depth=(3, 6, 9, 3),
|
|
|
|
|
stride=(2,) * 4,
|
|
|
|
|
bottle_ratio=(1.,) * 4,
|
|
|
|
|
block_ratio=(0.5,) * 4,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5),
|
|
|
|
|
cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67),
|
|
|
|
|
cs3darknet_l=_cs3darknet_cfg(),
|
|
|
|
|
cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33),
|
|
|
|
|
|
|
|
|
|
def create_stem(
|
|
|
|
|
in_chans=3,
|
|
|
|
|
out_chs=32,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=2,
|
|
|
|
|
pool='',
|
|
|
|
|
padding='',
|
|
|
|
|
act_layer=nn.ReLU,
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
aa_layer=None
|
|
|
|
|
):
|
|
|
|
|
stem = nn.Sequential()
|
|
|
|
|
if not isinstance(out_chs, (tuple, list)):
|
|
|
|
|
out_chs = [out_chs]
|
|
|
|
|
assert len(out_chs)
|
|
|
|
|
in_c = in_chans
|
|
|
|
|
for i, out_c in enumerate(out_chs):
|
|
|
|
|
conv_name = f'conv{i + 1}'
|
|
|
|
|
stem.add_module(conv_name, ConvNormAct(
|
|
|
|
|
in_c, out_c, kernel_size,
|
|
|
|
|
stride=stride if i == 0 else 1,
|
|
|
|
|
padding=padding if i == 0 else '',
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
norm_layer=norm_layer
|
|
|
|
|
))
|
|
|
|
|
in_c = out_c
|
|
|
|
|
last_conv = conv_name
|
|
|
|
|
if pool:
|
|
|
|
|
if aa_layer is not None:
|
|
|
|
|
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
|
|
|
|
|
stem.add_module('aa', aa_layer(channels=in_c, stride=2))
|
|
|
|
|
else:
|
|
|
|
|
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
|
|
|
|
return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
|
|
|
|
|
cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
|
|
|
|
|
cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
|
|
|
|
|
cs3darknet_focus_l=_cs3darknet_cfg(focus=True),
|
|
|
|
|
cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
|
|
|
|
|
|
|
|
|
|
cs3sedarknet_xdw=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
|
|
|
|
|
stages=CspStagesCfg(
|
|
|
|
|
depth=(3, 6, 12, 4),
|
|
|
|
|
out_chs=(256, 512, 1024, 2048),
|
|
|
|
|
stride=2,
|
|
|
|
|
groups=(1, 1, 256, 512),
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
block_ratio=0.5,
|
|
|
|
|
attn_layer='se',
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResBottleneck(nn.Module):
|
|
|
|
|
class BottleneckBlock(nn.Module):
|
|
|
|
|
""" ResNe(X)t Bottleneck Block
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -286,9 +349,9 @@ class ResBottleneck(nn.Module):
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
aa_layer=None,
|
|
|
|
|
drop_block=None,
|
|
|
|
|
drop_path=None
|
|
|
|
|
drop_path=0.
|
|
|
|
|
):
|
|
|
|
|
super(ResBottleneck, self).__init__()
|
|
|
|
|
super(BottleneckBlock, self).__init__()
|
|
|
|
|
mid_chs = int(round(out_chs * bottle_ratio))
|
|
|
|
|
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
|
|
|
|
|
|
|
|
@ -299,8 +362,8 @@ class ResBottleneck(nn.Module):
|
|
|
|
|
self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
|
|
|
|
|
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
|
|
|
|
|
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
|
|
|
|
|
self.drop_path = drop_path
|
|
|
|
|
self.act3 = act_layer()
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
|
|
|
|
self.act3 = create_act_layer(act_layer)
|
|
|
|
|
|
|
|
|
|
def zero_init_last(self):
|
|
|
|
|
nn.init.zeros_(self.conv3.bn.weight)
|
|
|
|
@ -314,9 +377,7 @@ class ResBottleneck(nn.Module):
|
|
|
|
|
x = self.conv3(x)
|
|
|
|
|
if self.attn3 is not None:
|
|
|
|
|
x = self.attn3(x)
|
|
|
|
|
if self.drop_path is not None:
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
x = x + shortcut
|
|
|
|
|
x = self.drop_path(x) + shortcut
|
|
|
|
|
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl
|
|
|
|
|
#x[:, :shortcut.size(1)] += shortcut
|
|
|
|
|
x = self.act3(x)
|
|
|
|
@ -339,7 +400,7 @@ class DarkBlock(nn.Module):
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
aa_layer=None,
|
|
|
|
|
drop_block=None,
|
|
|
|
|
drop_path=None
|
|
|
|
|
drop_path=0.
|
|
|
|
|
):
|
|
|
|
|
super(DarkBlock, self).__init__()
|
|
|
|
|
mid_chs = int(round(out_chs * bottle_ratio))
|
|
|
|
@ -349,7 +410,7 @@ class DarkBlock(nn.Module):
|
|
|
|
|
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
|
|
|
|
|
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
|
|
|
|
self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer)
|
|
|
|
|
self.drop_path = drop_path
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def zero_init_last(self):
|
|
|
|
|
nn.init.zeros_(self.conv2.bn.weight)
|
|
|
|
@ -360,9 +421,7 @@ class DarkBlock(nn.Module):
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
if self.attn is not None:
|
|
|
|
|
x = self.attn(x)
|
|
|
|
|
if self.drop_path is not None:
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
x = x + shortcut
|
|
|
|
|
x = self.drop_path(x) + shortcut
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -377,27 +436,27 @@ class CrossStage(nn.Module):
|
|
|
|
|
depth,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
bottle_ratio=1.,
|
|
|
|
|
exp_ratio=1.,
|
|
|
|
|
expand_ratio=1.,
|
|
|
|
|
groups=1,
|
|
|
|
|
first_dilation=None,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
down_growth=False,
|
|
|
|
|
cross_linear=False,
|
|
|
|
|
block_dpr=None,
|
|
|
|
|
block_fn=ResBottleneck,
|
|
|
|
|
block_fn=BottleneckBlock,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
):
|
|
|
|
|
super(CrossStage, self).__init__()
|
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
|
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
|
|
|
|
|
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
|
|
|
|
|
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
|
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
|
|
|
|
|
|
|
|
|
if stride != 1 or first_dilation != dilation:
|
|
|
|
|
if avg_down:
|
|
|
|
|
self.conv_down = nn.Sequential(
|
|
|
|
|
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
@ -417,9 +476,15 @@ class CrossStage(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.Sequential()
|
|
|
|
|
for i in range(depth):
|
|
|
|
|
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
|
|
|
|
|
self.blocks.add_module(str(i), block_fn(
|
|
|
|
|
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
|
|
|
|
|
in_chs=prev_chs,
|
|
|
|
|
out_chs=block_out_chs,
|
|
|
|
|
dilation=dilation,
|
|
|
|
|
bottle_ratio=bottle_ratio,
|
|
|
|
|
groups=groups,
|
|
|
|
|
drop_path=block_dpr[i] if block_dpr is not None else 0.,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
))
|
|
|
|
|
prev_chs = block_out_chs
|
|
|
|
|
|
|
|
|
|
# transition convs
|
|
|
|
@ -429,7 +494,7 @@ class CrossStage(nn.Module):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.conv_down(x)
|
|
|
|
|
x = self.conv_exp(x)
|
|
|
|
|
xs, xb = x.split(self.exp_chs // 2, dim=1)
|
|
|
|
|
xs, xb = x.split(self.expand_chs // 2, dim=1)
|
|
|
|
|
xb = self.blocks(xb)
|
|
|
|
|
xb = self.conv_transition_b(xb).contiguous()
|
|
|
|
|
out = self.conv_transition(torch.cat([xs, xb], dim=1))
|
|
|
|
@ -449,27 +514,27 @@ class CrossStage3(nn.Module):
|
|
|
|
|
depth,
|
|
|
|
|
block_ratio=1.,
|
|
|
|
|
bottle_ratio=1.,
|
|
|
|
|
exp_ratio=1.,
|
|
|
|
|
expand_ratio=1.,
|
|
|
|
|
groups=1,
|
|
|
|
|
first_dilation=None,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
down_growth=False,
|
|
|
|
|
cross_linear=False,
|
|
|
|
|
block_dpr=None,
|
|
|
|
|
block_fn=ResBottleneck,
|
|
|
|
|
block_fn=BottleneckBlock,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
):
|
|
|
|
|
super(CrossStage3, self).__init__()
|
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
|
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
|
|
|
|
|
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
|
|
|
|
|
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
|
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
|
|
|
|
|
|
|
|
|
if stride != 1 or first_dilation != dilation:
|
|
|
|
|
if avg_down:
|
|
|
|
|
self.conv_down = nn.Sequential(
|
|
|
|
|
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
@ -487,9 +552,15 @@ class CrossStage3(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.Sequential()
|
|
|
|
|
for i in range(depth):
|
|
|
|
|
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
|
|
|
|
|
self.blocks.add_module(str(i), block_fn(
|
|
|
|
|
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
|
|
|
|
|
in_chs=prev_chs,
|
|
|
|
|
out_chs=block_out_chs,
|
|
|
|
|
dilation=dilation,
|
|
|
|
|
bottle_ratio=bottle_ratio,
|
|
|
|
|
groups=groups,
|
|
|
|
|
drop_path=block_dpr[i] if block_dpr is not None else 0.,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
))
|
|
|
|
|
prev_chs = block_out_chs
|
|
|
|
|
|
|
|
|
|
# transition convs
|
|
|
|
@ -498,7 +569,7 @@ class CrossStage3(nn.Module):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.conv_down(x)
|
|
|
|
|
x = self.conv_exp(x)
|
|
|
|
|
x1, x2 = x.split(self.exp_chs // 2, dim=1)
|
|
|
|
|
x1, x2 = x.split(self.expand_chs // 2, dim=1)
|
|
|
|
|
x1 = self.blocks(x1)
|
|
|
|
|
out = self.conv_transition(torch.cat([x1, x2], dim=1))
|
|
|
|
|
return out
|
|
|
|
@ -519,7 +590,7 @@ class DarkStage(nn.Module):
|
|
|
|
|
groups=1,
|
|
|
|
|
first_dilation=None,
|
|
|
|
|
avg_down=False,
|
|
|
|
|
block_fn=ResBottleneck,
|
|
|
|
|
block_fn=BottleneckBlock,
|
|
|
|
|
block_dpr=None,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
):
|
|
|
|
@ -529,7 +600,7 @@ class DarkStage(nn.Module):
|
|
|
|
|
|
|
|
|
|
if avg_down:
|
|
|
|
|
self.conv_down = nn.Sequential(
|
|
|
|
|
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
|
|
|
|
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
@ -541,9 +612,15 @@ class DarkStage(nn.Module):
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
|
self.blocks = nn.Sequential()
|
|
|
|
|
for i in range(depth):
|
|
|
|
|
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
|
|
|
|
|
self.blocks.add_module(str(i), block_fn(
|
|
|
|
|
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
|
|
|
|
|
in_chs=prev_chs,
|
|
|
|
|
out_chs=block_out_chs,
|
|
|
|
|
dilation=dilation,
|
|
|
|
|
bottle_ratio=bottle_ratio,
|
|
|
|
|
groups=groups,
|
|
|
|
|
drop_path=block_dpr[i] if block_dpr is not None else 0.,
|
|
|
|
|
**block_kwargs
|
|
|
|
|
))
|
|
|
|
|
prev_chs = block_out_chs
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
@ -552,38 +629,131 @@ class DarkStage(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
|
|
|
|
|
# get per stage args for stage and containing blocks, calculate strides to meet target output_stride
|
|
|
|
|
num_stages = len(cfg['depth'])
|
|
|
|
|
if 'groups' not in cfg:
|
|
|
|
|
cfg['groups'] = (1,) * num_stages
|
|
|
|
|
if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)):
|
|
|
|
|
cfg['down_growth'] = (cfg['down_growth'],) * num_stages
|
|
|
|
|
if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
|
|
|
|
|
cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
|
|
|
|
|
if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)):
|
|
|
|
|
cfg['avg_down'] = (cfg['avg_down'],) * num_stages
|
|
|
|
|
cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
|
|
|
|
|
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
|
|
|
|
|
stage_strides = []
|
|
|
|
|
stage_dilations = []
|
|
|
|
|
stage_first_dilations = []
|
|
|
|
|
def create_csp_stem(
|
|
|
|
|
in_chans=3,
|
|
|
|
|
out_chs=32,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=2,
|
|
|
|
|
pool='',
|
|
|
|
|
padding='',
|
|
|
|
|
act_layer=nn.ReLU,
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
aa_layer=None
|
|
|
|
|
):
|
|
|
|
|
stem = nn.Sequential()
|
|
|
|
|
feature_info = []
|
|
|
|
|
if not isinstance(out_chs, (tuple, list)):
|
|
|
|
|
out_chs = [out_chs]
|
|
|
|
|
stem_depth = len(out_chs)
|
|
|
|
|
assert stem_depth
|
|
|
|
|
assert stride in (1, 2, 4)
|
|
|
|
|
prev_feat = None
|
|
|
|
|
prev_chs = in_chans
|
|
|
|
|
last_idx = stem_depth - 1
|
|
|
|
|
stem_stride = 1
|
|
|
|
|
for i, chs in enumerate(out_chs):
|
|
|
|
|
conv_name = f'conv{i + 1}'
|
|
|
|
|
conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
|
|
|
|
|
if conv_stride > 1 and prev_feat is not None:
|
|
|
|
|
feature_info.append(prev_feat)
|
|
|
|
|
stem.add_module(conv_name, ConvNormAct(
|
|
|
|
|
prev_chs, chs, kernel_size,
|
|
|
|
|
stride=conv_stride,
|
|
|
|
|
padding=padding if i == 0 else '',
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
norm_layer=norm_layer
|
|
|
|
|
))
|
|
|
|
|
stem_stride *= conv_stride
|
|
|
|
|
prev_chs = chs
|
|
|
|
|
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
|
|
|
|
|
if pool:
|
|
|
|
|
assert stride > 2
|
|
|
|
|
if prev_feat is not None:
|
|
|
|
|
feature_info.append(prev_feat)
|
|
|
|
|
if aa_layer is not None:
|
|
|
|
|
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
|
|
|
|
|
stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
|
|
|
|
|
pool_name = 'aa'
|
|
|
|
|
else:
|
|
|
|
|
stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
|
|
|
|
pool_name = 'pool'
|
|
|
|
|
stem_stride *= 2
|
|
|
|
|
prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
|
|
|
|
|
feature_info.append(prev_feat)
|
|
|
|
|
return stem, feature_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_stage_fn(stage_type: str, stage_args):
|
|
|
|
|
assert stage_type in ('dark', 'csp', 'cs3')
|
|
|
|
|
if stage_type == 'dark':
|
|
|
|
|
stage_args.pop('expand_ratio', None)
|
|
|
|
|
stage_args.pop('cross_linear', None)
|
|
|
|
|
stage_args.pop('down_growth', None)
|
|
|
|
|
stage_fn = DarkStage
|
|
|
|
|
elif stage_type == 'csp':
|
|
|
|
|
stage_fn = CrossStage
|
|
|
|
|
else:
|
|
|
|
|
stage_fn = CrossStage3
|
|
|
|
|
return stage_fn, stage_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_block_fn(stage_type: str, stage_args):
|
|
|
|
|
assert stage_type in ('dark', 'bottle')
|
|
|
|
|
if stage_type == 'dark':
|
|
|
|
|
return DarkBlock, stage_args
|
|
|
|
|
else:
|
|
|
|
|
return BottleneckBlock, stage_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_csp_stages(
|
|
|
|
|
cfg: CspModelCfg,
|
|
|
|
|
drop_path_rate: float,
|
|
|
|
|
output_stride: int,
|
|
|
|
|
stem_feat: Dict[str, Any]
|
|
|
|
|
):
|
|
|
|
|
cfg_dict = asdict(cfg.stages)
|
|
|
|
|
num_stages = len(cfg.stages.depth)
|
|
|
|
|
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
|
|
|
|
|
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
|
|
|
|
|
stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
|
|
|
|
|
block_kwargs = dict(
|
|
|
|
|
act_layer=cfg.act_layer,
|
|
|
|
|
norm_layer=cfg.norm_layer,
|
|
|
|
|
aa_layer=cfg.aa_layer
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
dilation = 1
|
|
|
|
|
for cfg_stride in cfg['stride']:
|
|
|
|
|
stage_first_dilations.append(dilation)
|
|
|
|
|
if curr_stride >= output_stride:
|
|
|
|
|
dilation *= cfg_stride
|
|
|
|
|
net_stride = stem_feat['reduction']
|
|
|
|
|
prev_chs = stem_feat['num_chs']
|
|
|
|
|
prev_feat = stem_feat
|
|
|
|
|
feature_info = []
|
|
|
|
|
stages = []
|
|
|
|
|
for stage_idx, stage_args in enumerate(stage_args):
|
|
|
|
|
stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args)
|
|
|
|
|
block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args)
|
|
|
|
|
stride = stage_args.pop('stride')
|
|
|
|
|
if stride != 1 and prev_feat:
|
|
|
|
|
feature_info.append(prev_feat)
|
|
|
|
|
if net_stride >= output_stride and stride > 1:
|
|
|
|
|
dilation *= stride
|
|
|
|
|
stride = 1
|
|
|
|
|
else:
|
|
|
|
|
stride = cfg_stride
|
|
|
|
|
curr_stride *= stride
|
|
|
|
|
stage_strides.append(stride)
|
|
|
|
|
stage_dilations.append(dilation)
|
|
|
|
|
cfg['stride'] = stage_strides
|
|
|
|
|
cfg['dilation'] = stage_dilations
|
|
|
|
|
cfg['first_dilation'] = stage_first_dilations
|
|
|
|
|
stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())]
|
|
|
|
|
return stage_args
|
|
|
|
|
net_stride *= stride
|
|
|
|
|
first_dilation = 1 if dilation in (1, 2) else 2
|
|
|
|
|
|
|
|
|
|
stages += [stage_fn(
|
|
|
|
|
prev_chs,
|
|
|
|
|
**stage_args,
|
|
|
|
|
stride=stride,
|
|
|
|
|
first_dilation=first_dilation,
|
|
|
|
|
dilation=dilation,
|
|
|
|
|
block_fn=block_fn,
|
|
|
|
|
**block_kwargs,
|
|
|
|
|
)]
|
|
|
|
|
prev_chs = stage_args['out_chs']
|
|
|
|
|
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
|
|
|
|
|
|
|
|
|
|
feature_info.append(prev_feat)
|
|
|
|
|
return nn.Sequential(*stages), feature_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CspNet(nn.Module):
|
|
|
|
@ -598,43 +768,39 @@ class CspNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
cfg,
|
|
|
|
|
cfg: CspModelCfg,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
num_classes=1000,
|
|
|
|
|
output_stride=32,
|
|
|
|
|
global_pool='avg',
|
|
|
|
|
act_layer=nn.LeakyReLU,
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
aa_layer=None,
|
|
|
|
|
drop_rate=0.,
|
|
|
|
|
drop_path_rate=0.,
|
|
|
|
|
zero_init_last=True,
|
|
|
|
|
stage_fn=CrossStage,
|
|
|
|
|
block_fn=ResBottleneck):
|
|
|
|
|
zero_init_last=True
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
assert output_stride in (8, 16, 32)
|
|
|
|
|
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
|
|
|
|
|
layer_args = dict(
|
|
|
|
|
act_layer=cfg.act_layer,
|
|
|
|
|
norm_layer=cfg.norm_layer,
|
|
|
|
|
aa_layer=cfg.aa_layer
|
|
|
|
|
)
|
|
|
|
|
self.feature_info = []
|
|
|
|
|
|
|
|
|
|
# Construct the stem
|
|
|
|
|
self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)
|
|
|
|
|
self.feature_info = [stem_feat_info]
|
|
|
|
|
prev_chs = stem_feat_info['num_chs']
|
|
|
|
|
curr_stride = stem_feat_info['reduction'] # reduction does not include pool
|
|
|
|
|
if cfg['stem']['pool']:
|
|
|
|
|
curr_stride *= 2
|
|
|
|
|
self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
|
|
|
|
|
self.feature_info.extend(stem_feat_info[:-1])
|
|
|
|
|
|
|
|
|
|
# Construct the stages
|
|
|
|
|
per_stage_args = _cfg_to_stage_args(
|
|
|
|
|
cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)
|
|
|
|
|
self.stages = nn.Sequential()
|
|
|
|
|
for i, sa in enumerate(per_stage_args):
|
|
|
|
|
self.stages.add_module(
|
|
|
|
|
str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
|
|
|
|
|
prev_chs = sa['out_chs']
|
|
|
|
|
curr_stride *= sa['stride']
|
|
|
|
|
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
|
|
|
|
|
self.stages, stage_feat_info = create_csp_stages(
|
|
|
|
|
cfg,
|
|
|
|
|
drop_path_rate=drop_path_rate,
|
|
|
|
|
output_stride=output_stride,
|
|
|
|
|
stem_feat=stem_feat_info[-1],
|
|
|
|
|
)
|
|
|
|
|
prev_chs = stage_feat_info[-1]['num_chs']
|
|
|
|
|
self.feature_info.extend(stage_feat_info)
|
|
|
|
|
|
|
|
|
|
# Construct the head
|
|
|
|
|
self.num_features = prev_chs
|
|
|
|
@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cspdarknet53(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
|
|
|
|
|
return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def darknet17(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
|
|
|
|
return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def darknet21(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
|
|
|
|
return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def sedarknet21(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
|
|
|
|
return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def darknet53(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
|
|
|
|
return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def darknetaa53(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet(
|
|
|
|
|
'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
|
|
|
|
return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_s(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_m(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet(
|
|
|
|
|
'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
|
|
|
|
|
return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_l(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet(
|
|
|
|
|
'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
|
|
|
|
|
return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_x(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_focus_s(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_focus_m(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet(
|
|
|
|
|
'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
|
|
|
|
|
return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_focus_l(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet(
|
|
|
|
|
'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs)
|
|
|
|
|
return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3darknet_focus_x(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3sedarknet_xdw(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
|
|
|
|
|