diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 4591f101..f0a26baf 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ +import collections.abc +from dataclasses import dataclass, field, asdict from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,7 +23,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer +from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible from .registry import register_model @@ -58,218 +61,278 @@ default_cfgs = { ), 'darknetaa53': _cfg(url=''), + 'cs3darknet_s': _cfg( + url=''), 'cs3darknet_m': _cfg( url=''), 'cs3darknet_l': _cfg( url=''), + 'cs3darknet_x': _cfg( + url=''), + + 'cs3darknet_focus_s': _cfg( + url=''), 'cs3darknet_focus_m': _cfg( url=''), 'cs3darknet_focus_l': _cfg( url=''), + 'cs3darknet_focus_x': _cfg( + url=''), + + 'cs3sedarknet_xdw': _cfg( + url=''), } +@dataclass +class CspStemCfg: + out_chs: Union[int, Tuple[int, ...]] = 32 + stride: Union[int, Tuple[int, ...]] = 2 + kernel_size: int = 3 + padding: Union[int, str] = '' + pool: Optional[str] = '' + + +def _pad_arg(x, n): + # pads an argument tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + curr_n = len(x) + pad_n = n - curr_n + if pad_n <= 0: + return x[:n] + return tuple(x + (x[-1],) * pad_n) + + +@dataclass +class CspStagesCfg: + depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages) + out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage + stride: Union[int, Tuple[int, ...]] = 2 # stride of stage + groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups + block_ratio: Union[float, Tuple[float, ...]] = 1.0 + bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage + avg_down: Union[bool, Tuple[bool, ...]] = False + attn_layer: Optional[Union[str, Tuple[str, ...]]] = None + stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark') + block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark') + + # cross-stage only + expand_ratio: Union[float, Tuple[float, ...]] = 1.0 + cross_linear: Union[bool, Tuple[bool, ...]] = False + down_growth: Union[bool, Tuple[bool, ...]] = False + + def __post_init__(self): + n = len(self.depth) + assert len(self.out_chs) == n + self.stride = _pad_arg(self.stride, n) + self.groups = _pad_arg(self.groups, n) + self.block_ratio = _pad_arg(self.block_ratio, n) + self.bottle_ratio = _pad_arg(self.bottle_ratio, n) + self.avg_down = _pad_arg(self.avg_down, n) + self.attn_layer = _pad_arg(self.attn_layer, n) + self.stage_type = _pad_arg(self.stage_type, n) + self.block_type = _pad_arg(self.block_type, n) + + self.expand_ratio = _pad_arg(self.expand_ratio, n) + self.cross_linear = _pad_arg(self.cross_linear, n) + self.down_growth = _pad_arg(self.down_growth, n) + + +@dataclass +class CspModelCfg: + stem: CspStemCfg + stages: CspStagesCfg + zero_init_last: bool = True # zero init last weight (usually bn) in residual path + act_layer: str = 'relu' + norm_layer: str = 'batchnorm' + aa_layer: Optional[str] = None # FIXME support string factory for this + + +def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False): + if focus: + stem_cfg = CspStemCfg( + out_chs=make_divisible(64 * width_multiplier), + kernel_size=6, stride=2, padding=2, pool='') + else: + stem_cfg = CspStemCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]), + kernel_size=3, stride=2, pool='') + return CspModelCfg( + stem=stem_cfg, + stages=CspStagesCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]), + depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]), + stride=2, + bottle_ratio=1., + block_ratio=0.5, + avg_down=avg_down, + stage_type='cs3', + block_type='dark', + ), + act_layer=act_layer, + ) + + model_cfgs = dict( - cspresnet50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1, 2), + expand_ratio=2., + bottle_ratio=0.5, cross_linear=True, - ) + ), ), - cspresnet50d=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50d=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1,) + (2,), + expand_ratio=2., + bottle_ratio=0.5, + block_ratio=1., cross_linear=True, ) ), - cspresnet50w=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnet50w=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(1.,) * 4, - bottle_ratio=(0.25,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + expand_ratio=1., + bottle_ratio=0.25, + block_ratio=0.5, cross_linear=True, ) ), - cspresnext50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnext50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - groups=(32,) * 4, - exp_ratio=(1.,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + groups=32, + expand_ratio=1., + bottle_ratio=1., + block_ratio=0.5, cross_linear=True, ) ), - cspdarknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + cspdarknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - exp_ratio=(2.,) + (1.,) * 4, - bottle_ratio=(0.5,) + (1.0,) * 4, - block_ratio=(1.,) + (0.5,) * 4, + out_chs=(64, 128, 256, 512, 1024), + stride=2, + expand_ratio=(2.,) + (1.,), + bottle_ratio=(0.5,) + (1.,), + block_ratio=(1.,) + (0.5,), down_growth=True, - ) + block_type='dark', + ), + act_layer='leaky_relu', ), - darknet17=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet17=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1,) * 5, - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), - darknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', ), - sedarknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - attn_layer=('se',) * 5, - ) - ), - darknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', - darknetaa53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - avg_down=True, ), + act_layer='leaky_relu', ), + sedarknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 1, 1, 2, 2), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + attn_layer='se', + stage_type='dark', + block_type='dark', - cs3darknet_m=dict( - stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, ), + act_layer='leaky_relu', ), - cs3darknet_l=dict( - stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - - cs3darknet_focus_m=dict( - stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknetaa53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + avg_down=True, + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - cs3darknet_focus_l=dict( - stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, - ), - ) -) + cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5), + cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67), + cs3darknet_l=_cs3darknet_cfg(), + cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33), -def create_stem( - in_chans=3, - out_chs=32, - kernel_size=3, - stride=2, - pool='', - padding='', - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None -): - stem = nn.Sequential() - if not isinstance(out_chs, (tuple, list)): - out_chs = [out_chs] - assert len(out_chs) - in_c = in_chans - for i, out_c in enumerate(out_chs): - conv_name = f'conv{i + 1}' - stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, - stride=stride if i == 0 else 1, - padding=padding if i == 0 else '', - act_layer=act_layer, - norm_layer=norm_layer - )) - in_c = out_c - last_conv = conv_name - if pool: - if aa_layer is not None: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) - stem.add_module('aa', aa_layer(channels=in_c, stride=2)) - else: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True), + cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True), + cs3darknet_focus_l=_cs3darknet_cfg(focus=True), + cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True), + + cs3sedarknet_xdw=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stages=CspStagesCfg( + depth=(3, 6, 12, 4), + out_chs=(256, 512, 1024, 2048), + stride=2, + groups=(1, 1, 256, 512), + bottle_ratio=0.5, + block_ratio=0.5, + attn_layer='se', + ), + ), +) -class ResBottleneck(nn.Module): +class BottleneckBlock(nn.Module): """ ResNe(X)t Bottleneck Block """ @@ -286,9 +349,9 @@ class ResBottleneck(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): - super(ResBottleneck, self).__init__() + super(BottleneckBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -299,8 +362,8 @@ class ResBottleneck(nn.Module): self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None - self.drop_path = drop_path - self.act3 = act_layer() + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() + self.act3 = create_act_layer(act_layer) def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -314,9 +377,7 @@ class ResBottleneck(nn.Module): x = self.conv3(x) if self.attn3 is not None: x = self.attn3(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut # FIXME partial shortcut needed if first block handled as per original, not used for my current impl #x[:, :shortcut.size(1)] += shortcut x = self.act3(x) @@ -339,7 +400,7 @@ class DarkBlock(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) @@ -349,7 +410,7 @@ class DarkBlock(nn.Module): mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) - self.drop_path = drop_path + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) @@ -360,9 +421,7 @@ class DarkBlock(nn.Module): x = self.conv2(x) if self.attn is not None: x = self.attn(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut return x @@ -377,27 +436,27 @@ class CrossStage(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -417,9 +476,15 @@ class CrossStage(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -429,7 +494,7 @@ class CrossStage(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.split(self.exp_chs // 2, dim=1) + xs, xb = x.split(self.expand_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) @@ -449,27 +514,27 @@ class CrossStage3(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -487,9 +552,15 @@ class CrossStage3(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -498,7 +569,7 @@ class CrossStage3(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - x1, x2 = x.split(self.exp_chs // 2, dim=1) + x1, x2 = x.split(self.expand_chs // 2, dim=1) x1 = self.blocks(x1) out = self.conv_transition(torch.cat([x1, x2], dim=1)) return out @@ -519,7 +590,7 @@ class DarkStage(nn.Module): groups=1, first_dilation=None, avg_down=False, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, block_dpr=None, **block_kwargs ): @@ -529,7 +600,7 @@ class DarkStage(nn.Module): if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -541,9 +612,15 @@ class DarkStage(nn.Module): block_out_chs = int(round(out_chs * block_ratio)) self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs def forward(self, x): @@ -552,38 +629,131 @@ class DarkStage(nn.Module): return x -def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): - # get per stage args for stage and containing blocks, calculate strides to meet target output_stride - num_stages = len(cfg['depth']) - if 'groups' not in cfg: - cfg['groups'] = (1,) * num_stages - if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): - cfg['down_growth'] = (cfg['down_growth'],) * num_stages - if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): - cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages - if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): - cfg['avg_down'] = (cfg['avg_down'],) * num_stages - cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ - [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] - stage_strides = [] - stage_dilations = [] - stage_first_dilations = [] +def create_csp_stem( + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): + stem = nn.Sequential() + feature_info = [] + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + stem_depth = len(out_chs) + assert stem_depth + assert stride in (1, 2, 4) + prev_feat = None + prev_chs = in_chans + last_idx = stem_depth - 1 + stem_stride = 1 + for i, chs in enumerate(out_chs): + conv_name = f'conv{i + 1}' + conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1 + if conv_stride > 1 and prev_feat is not None: + feature_info.append(prev_feat) + stem.add_module(conv_name, ConvNormAct( + prev_chs, chs, kernel_size, + stride=conv_stride, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) + stem_stride *= conv_stride + prev_chs = chs + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name])) + if pool: + assert stride > 2 + if prev_feat is not None: + feature_info.append(prev_feat) + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=prev_chs, stride=2)) + pool_name = 'aa' + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + pool_name = 'pool' + stem_stride *= 2 + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name])) + feature_info.append(prev_feat) + return stem, feature_info + + +def _get_stage_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'csp', 'cs3') + if stage_type == 'dark': + stage_args.pop('expand_ratio', None) + stage_args.pop('cross_linear', None) + stage_args.pop('down_growth', None) + stage_fn = DarkStage + elif stage_type == 'csp': + stage_fn = CrossStage + else: + stage_fn = CrossStage3 + return stage_fn, stage_args + + +def _get_block_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'bottle') + if stage_type == 'dark': + return DarkBlock, stage_args + else: + return BottleneckBlock, stage_args + + +def create_csp_stages( + cfg: CspModelCfg, + drop_path_rate: float, + output_stride: int, + stem_feat: Dict[str, Any] +): + cfg_dict = asdict(cfg.stages) + num_stages = len(cfg.stages.depth) + cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)] + stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())] + block_kwargs = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + dilation = 1 - for cfg_stride in cfg['stride']: - stage_first_dilations.append(dilation) - if curr_stride >= output_stride: - dilation *= cfg_stride + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + feature_info = [] + stages = [] + for stage_idx, stage_args in enumerate(stage_args): + stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args) + block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args) + stride = stage_args.pop('stride') + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride stride = 1 - else: - stride = cfg_stride - curr_stride *= stride - stage_strides.append(stride) - stage_dilations.append(dilation) - cfg['stride'] = stage_strides - cfg['dilation'] = stage_dilations - cfg['first_dilation'] = stage_first_dilations - stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] - return stage_args + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + stages += [stage_fn( + prev_chs, + **stage_args, + stride=stride, + first_dilation=first_dilation, + dilation=dilation, + block_fn=block_fn, + **block_kwargs, + )] + prev_chs = stage_args['out_chs'] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info class CspNet(nn.Module): @@ -598,43 +768,39 @@ class CspNet(nn.Module): def __init__( self, - cfg, + cfg: CspModelCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', - act_layer=nn.LeakyReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None, drop_rate=0., drop_path_rate=0., - zero_init_last=True, - stage_fn=CrossStage, - block_fn=ResBottleneck): + zero_init_last=True + ): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + layer_args = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + self.feature_info = [] # Construct the stem - self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) - self.feature_info = [stem_feat_info] - prev_chs = stem_feat_info['num_chs'] - curr_stride = stem_feat_info['reduction'] # reduction does not include pool - if cfg['stem']['pool']: - curr_stride *= 2 + self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args) + self.feature_info.extend(stem_feat_info[:-1]) # Construct the stages - per_stage_args = _cfg_to_stage_args( - cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) - self.stages = nn.Sequential() - for i, sa in enumerate(per_stage_args): - self.stages.add_module( - str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) - prev_chs = sa['out_chs'] - curr_stride *= sa['stride'] - self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages, stage_feat_info = create_csp_stages( + cfg, + drop_path_rate=drop_path_rate, + output_stride=output_stride, + stem_feat=stem_feat_info[-1], + ) + prev_chs = stage_feat_info[-1]['num_chs'] + self.feature_info.extend(stage_feat_info) # Construct the head self.num_features = prev_chs @@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs): @register_model def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs) @register_model def darknet17(pretrained=False, **kwargs): - return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet17', pretrained=pretrained, **kwargs) @register_model def darknet21(pretrained=False, **kwargs): - return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet21', pretrained=pretrained, **kwargs) @register_model def sedarknet21(pretrained=False, **kwargs): - return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): - return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet53', pretrained=pretrained, **kwargs) @register_model def darknetaa53(pretrained=False, **kwargs): - return _create_cspnet( - 'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) \ No newline at end of file + return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3sedarknet_xdw(pretrained=False, **kwargs): + return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)