|
|
|
@ -89,9 +89,17 @@ default_cfgs = {
|
|
|
|
|
'cs3sedarknet_l': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
|
|
|
|
|
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
|
|
|
|
'cs3sedarknet_x': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
|
|
|
|
|
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'cs3sedarknet_xdw': _cfg(
|
|
|
|
|
url='', interpolation='bicubic'),
|
|
|
|
|
|
|
|
|
|
'cs3edgenet_x': _cfg(
|
|
|
|
|
url='', interpolation='bicubic'),
|
|
|
|
|
'cs3se_edgenet_x': _cfg(
|
|
|
|
|
url='', interpolation='bicubic'),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -162,7 +170,7 @@ class CspModelCfg:
|
|
|
|
|
aa_layer: Optional[str] = None # FIXME support string factory for this
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cs3darknet_cfg(
|
|
|
|
|
def _cs3_cfg(
|
|
|
|
|
width_multiplier=1.0,
|
|
|
|
|
depth_multiplier=1.0,
|
|
|
|
|
avg_down=False,
|
|
|
|
@ -170,6 +178,8 @@ def _cs3darknet_cfg(
|
|
|
|
|
focus=False,
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
attn_kwargs=None,
|
|
|
|
|
bottle_ratio=1.0,
|
|
|
|
|
block_type='dark',
|
|
|
|
|
):
|
|
|
|
|
if focus:
|
|
|
|
|
stem_cfg = CspStemCfg(
|
|
|
|
@ -185,13 +195,13 @@ def _cs3darknet_cfg(
|
|
|
|
|
out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
|
|
|
|
|
depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
|
|
|
|
|
stride=2,
|
|
|
|
|
bottle_ratio=1.,
|
|
|
|
|
bottle_ratio=bottle_ratio,
|
|
|
|
|
block_ratio=0.5,
|
|
|
|
|
avg_down=avg_down,
|
|
|
|
|
attn_layer=attn_layer,
|
|
|
|
|
attn_kwargs=attn_kwargs,
|
|
|
|
|
stage_type='cs3',
|
|
|
|
|
block_type='dark',
|
|
|
|
|
block_type=block_type,
|
|
|
|
|
),
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
)
|
|
|
|
@ -324,17 +334,18 @@ model_cfgs = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5),
|
|
|
|
|
cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67),
|
|
|
|
|
cs3darknet_l=_cs3darknet_cfg(),
|
|
|
|
|
cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33),
|
|
|
|
|
cs3darknet_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
|
|
|
|
|
cs3darknet_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
|
|
|
|
|
cs3darknet_l=_cs3_cfg(),
|
|
|
|
|
cs3darknet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),
|
|
|
|
|
|
|
|
|
|
cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
|
|
|
|
|
cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
|
|
|
|
|
cs3darknet_focus_l=_cs3darknet_cfg(focus=True),
|
|
|
|
|
cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
|
|
|
|
|
cs3darknet_focus_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
|
|
|
|
|
cs3darknet_focus_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
|
|
|
|
|
cs3darknet_focus_l=_cs3_cfg(focus=True),
|
|
|
|
|
cs3darknet_focus_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
|
|
|
|
|
|
|
|
|
|
cs3sedarknet_l=_cs3darknet_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
|
|
|
|
|
cs3sedarknet_l=_cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
|
|
|
|
|
cs3sedarknet_x=_cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),
|
|
|
|
|
|
|
|
|
|
cs3sedarknet_xdw=CspModelCfg(
|
|
|
|
|
stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
|
|
|
|
@ -349,6 +360,11 @@ model_cfgs = dict(
|
|
|
|
|
),
|
|
|
|
|
act_layer='silu',
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
cs3edgenet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
|
|
|
|
|
cs3se_edgenet_x=_cs3_cfg(
|
|
|
|
|
width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
|
|
|
|
|
attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -367,7 +383,6 @@ class BottleneckBlock(nn.Module):
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_last=False,
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
aa_layer=None,
|
|
|
|
|
drop_block=None,
|
|
|
|
|
drop_path=0.
|
|
|
|
|
):
|
|
|
|
@ -378,9 +393,9 @@ class BottleneckBlock(nn.Module):
|
|
|
|
|
attn_first = attn_layer is not None and not attn_last
|
|
|
|
|
|
|
|
|
|
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
|
|
|
|
|
self.conv2 = ConvNormActAa(
|
|
|
|
|
self.conv2 = ConvNormAct(
|
|
|
|
|
mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
|
|
|
|
|
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
|
|
|
|
drop_layer=drop_block, **ckwargs)
|
|
|
|
|
self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if attn_first else nn.Identity()
|
|
|
|
|
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
|
|
|
|
|
self.attn3 = attn_layer(out_chs, act_layer=act_layer) if attn_last else nn.Identity()
|
|
|
|
@ -418,7 +433,6 @@ class DarkBlock(nn.Module):
|
|
|
|
|
act_layer=nn.ReLU,
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
aa_layer=None,
|
|
|
|
|
drop_block=None,
|
|
|
|
|
drop_path=0.
|
|
|
|
|
):
|
|
|
|
@ -428,9 +442,49 @@ class DarkBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
|
|
|
|
|
self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
|
|
|
|
|
self.conv2 = ConvNormActAa(
|
|
|
|
|
self.conv2 = ConvNormAct(
|
|
|
|
|
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
|
|
|
|
|
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
|
|
|
|
drop_layer=drop_block, **ckwargs)
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def zero_init_last(self):
|
|
|
|
|
nn.init.zeros_(self.conv2.bn.weight)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1(x)
|
|
|
|
|
x = self.attn(x)
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
x = self.drop_path(x) + shortcut
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EdgeBlock(nn.Module):
|
|
|
|
|
""" EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
dilation=1,
|
|
|
|
|
bottle_ratio=0.5,
|
|
|
|
|
groups=1,
|
|
|
|
|
act_layer=nn.ReLU,
|
|
|
|
|
norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None,
|
|
|
|
|
drop_block=None,
|
|
|
|
|
drop_path=0.
|
|
|
|
|
):
|
|
|
|
|
super(EdgeBlock, self).__init__()
|
|
|
|
|
mid_chs = int(round(out_chs * bottle_ratio))
|
|
|
|
|
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
|
|
|
|
|
|
|
|
|
self.conv1 = ConvNormAct(
|
|
|
|
|
in_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
|
|
|
|
|
drop_layer=drop_block, **ckwargs)
|
|
|
|
|
self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
|
|
|
|
|
self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs)
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def zero_init_last(self):
|
|
|
|
@ -472,6 +526,7 @@ class CrossStage(nn.Module):
|
|
|
|
|
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
|
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
|
|
|
|
aa_layer = block_kwargs.pop('aa_layer', None)
|
|
|
|
|
|
|
|
|
|
if stride != 1 or first_dilation != dilation:
|
|
|
|
|
if avg_down:
|
|
|
|
@ -482,7 +537,7 @@ class CrossStage(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
self.conv_down = ConvNormActAa(
|
|
|
|
|
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
|
|
|
|
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
|
|
|
|
aa_layer=aa_layer, **conv_kwargs)
|
|
|
|
|
prev_chs = down_chs
|
|
|
|
|
else:
|
|
|
|
|
self.conv_down = nn.Identity()
|
|
|
|
@ -550,6 +605,7 @@ class CrossStage3(nn.Module):
|
|
|
|
|
self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
|
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
|
|
|
|
aa_layer = block_kwargs.pop('aa_layer', None)
|
|
|
|
|
|
|
|
|
|
if stride != 1 or first_dilation != dilation:
|
|
|
|
|
if avg_down:
|
|
|
|
@ -560,7 +616,7 @@ class CrossStage3(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
self.conv_down = ConvNormActAa(
|
|
|
|
|
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
|
|
|
|
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
|
|
|
|
aa_layer=aa_layer, **conv_kwargs)
|
|
|
|
|
prev_chs = down_chs
|
|
|
|
|
else:
|
|
|
|
|
self.conv_down = None
|
|
|
|
@ -617,6 +673,7 @@ class DarkStage(nn.Module):
|
|
|
|
|
super(DarkStage, self).__init__()
|
|
|
|
|
first_dilation = first_dilation or dilation
|
|
|
|
|
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
|
|
|
|
aa_layer = block_kwargs.pop('aa_layer', None)
|
|
|
|
|
|
|
|
|
|
if avg_down:
|
|
|
|
|
self.conv_down = nn.Sequential(
|
|
|
|
@ -626,7 +683,7 @@ class DarkStage(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
self.conv_down = ConvNormActAa(
|
|
|
|
|
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
|
|
|
|
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
|
|
|
|
aa_layer=aa_layer, **conv_kwargs)
|
|
|
|
|
|
|
|
|
|
prev_chs = out_chs
|
|
|
|
|
block_out_chs = int(round(out_chs * block_ratio))
|
|
|
|
@ -720,9 +777,11 @@ def _get_stage_fn(stage_args):
|
|
|
|
|
|
|
|
|
|
def _get_block_fn(stage_args):
|
|
|
|
|
block_type = stage_args.pop('block_type')
|
|
|
|
|
assert block_type in ('dark', 'bottle')
|
|
|
|
|
assert block_type in ('dark', 'edge', 'bottle')
|
|
|
|
|
if block_type == 'dark':
|
|
|
|
|
return DarkBlock, stage_args
|
|
|
|
|
elif block_type == 'edge':
|
|
|
|
|
return EdgeBlock, stage_args
|
|
|
|
|
else:
|
|
|
|
|
return BottleneckBlock, stage_args
|
|
|
|
|
|
|
|
|
@ -751,7 +810,6 @@ def create_csp_stages(
|
|
|
|
|
block_kwargs = dict(
|
|
|
|
|
act_layer=cfg.act_layer,
|
|
|
|
|
norm_layer=cfg.norm_layer,
|
|
|
|
|
aa_layer=cfg.aa_layer
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
dilation = 1
|
|
|
|
@ -780,6 +838,7 @@ def create_csp_stages(
|
|
|
|
|
first_dilation=first_dilation,
|
|
|
|
|
dilation=dilation,
|
|
|
|
|
block_fn=block_fn,
|
|
|
|
|
aa_layer=cfg.aa_layer,
|
|
|
|
|
attn_layer=attn_fn, # will be passed through stage as block_kwargs
|
|
|
|
|
**block_kwargs,
|
|
|
|
|
)]
|
|
|
|
@ -1002,6 +1061,21 @@ def cs3sedarknet_l(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3sedarknet_x(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3sedarknet_x', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3sedarknet_xdw(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3edgenet_x(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3edgenet_x', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def cs3se_edgenet_x(pretrained=False, **kwargs):
|
|
|
|
|
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
|