@ -23,7 +23,7 @@ import torch.nn.functional as F
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from . helpers import build_model_with_cfg , named_apply , MATCH_PREV_GROUP
from . helpers import build_model_with_cfg , named_apply , MATCH_PREV_GROUP
from . layers import ClassifierHead , ConvNormAct , ConvNormActAa , DropPath , create _attn, create_act_layer , make_divisible
from . layers import ClassifierHead , ConvNormAct , ConvNormActAa , DropPath , get _attn, create_act_layer , make_divisible
from . registry import register_model
from . registry import register_model
@ -57,9 +57,10 @@ default_cfgs = {
' sedarknet21 ' : _cfg ( url = ' ' ) ,
' sedarknet21 ' : _cfg ( url = ' ' ) ,
' darknet53 ' : _cfg (
' darknet53 ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth ' ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth ' ,
interpolation = ' bicubic ' , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ,
interpolation = ' bicubic ' , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
) ,
' darknetaa53 ' : _cfg (
' darknetaa53 ' : _cfg ( url = ' ' ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' cs3darknet_s ' : _cfg (
' cs3darknet_s ' : _cfg (
url = ' ' , interpolation = ' bicubic ' ) ,
url = ' ' , interpolation = ' bicubic ' ) ,
@ -71,7 +72,8 @@ default_cfgs = {
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth ' ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth ' ,
interpolation = ' bicubic ' , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
interpolation = ' bicubic ' , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' cs3darknet_x ' : _cfg (
' cs3darknet_x ' : _cfg (
url = ' ' ) ,
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth ' ,
interpolation = ' bicubic ' , crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' cs3darknet_focus_s ' : _cfg (
' cs3darknet_focus_s ' : _cfg (
url = ' ' , interpolation = ' bicubic ' ) ,
url = ' ' , interpolation = ' bicubic ' ) ,
@ -84,6 +86,10 @@ default_cfgs = {
' cs3darknet_focus_x ' : _cfg (
' cs3darknet_focus_x ' : _cfg (
url = ' ' , interpolation = ' bicubic ' ) ,
url = ' ' , interpolation = ' bicubic ' ) ,
' cs3sedarknet_l ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth ' ,
interpolation = ' bicubic ' , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' cs3sedarknet_xdw ' : _cfg (
' cs3sedarknet_xdw ' : _cfg (
url = ' ' , interpolation = ' bicubic ' ) ,
url = ' ' , interpolation = ' bicubic ' ) ,
}
}
@ -119,6 +125,7 @@ class CspStagesCfg:
bottle_ratio : Union [ float , Tuple [ float , . . . ] ] = 1. # bottleneck-ratio of blocks in stage
bottle_ratio : Union [ float , Tuple [ float , . . . ] ] = 1. # bottleneck-ratio of blocks in stage
avg_down : Union [ bool , Tuple [ bool , . . . ] ] = False
avg_down : Union [ bool , Tuple [ bool , . . . ] ] = False
attn_layer : Optional [ Union [ str , Tuple [ str , . . . ] ] ] = None
attn_layer : Optional [ Union [ str , Tuple [ str , . . . ] ] ] = None
attn_kwargs : Optional [ Union [ Dict , Tuple [ Dict ] ] ] = None
stage_type : Union [ str , Tuple [ str ] ] = ' csp ' # stage type ('csp', 'cs2', 'dark')
stage_type : Union [ str , Tuple [ str ] ] = ' csp ' # stage type ('csp', 'cs2', 'dark')
block_type : Union [ str , Tuple [ str ] ] = ' bottle ' # blocks type for stages ('bottle', 'dark')
block_type : Union [ str , Tuple [ str ] ] = ' bottle ' # blocks type for stages ('bottle', 'dark')
@ -136,6 +143,7 @@ class CspStagesCfg:
self . bottle_ratio = _pad_arg ( self . bottle_ratio , n )
self . bottle_ratio = _pad_arg ( self . bottle_ratio , n )
self . avg_down = _pad_arg ( self . avg_down , n )
self . avg_down = _pad_arg ( self . avg_down , n )
self . attn_layer = _pad_arg ( self . attn_layer , n )
self . attn_layer = _pad_arg ( self . attn_layer , n )
self . attn_kwargs = _pad_arg ( self . attn_kwargs , n )
self . stage_type = _pad_arg ( self . stage_type , n )
self . stage_type = _pad_arg ( self . stage_type , n )
self . block_type = _pad_arg ( self . block_type , n )
self . block_type = _pad_arg ( self . block_type , n )
@ -149,12 +157,20 @@ class CspModelCfg:
stem : CspStemCfg
stem : CspStemCfg
stages : CspStagesCfg
stages : CspStagesCfg
zero_init_last : bool = True # zero init last weight (usually bn) in residual path
zero_init_last : bool = True # zero init last weight (usually bn) in residual path
act_layer : str = ' relu'
act_layer : str = ' leaky_ relu'
norm_layer : str = ' batchnorm '
norm_layer : str = ' batchnorm '
aa_layer : Optional [ str ] = None # FIXME support string factory for this
aa_layer : Optional [ str ] = None # FIXME support string factory for this
def _cs3darknet_cfg ( width_multiplier = 1.0 , depth_multiplier = 1.0 , avg_down = False , act_layer = ' silu ' , focus = False ) :
def _cs3darknet_cfg (
width_multiplier = 1.0 ,
depth_multiplier = 1.0 ,
avg_down = False ,
act_layer = ' silu ' ,
focus = False ,
attn_layer = None ,
attn_kwargs = None ,
) :
if focus :
if focus :
stem_cfg = CspStemCfg (
stem_cfg = CspStemCfg (
out_chs = make_divisible ( 64 * width_multiplier ) ,
out_chs = make_divisible ( 64 * width_multiplier ) ,
@ -172,6 +188,8 @@ def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False,
bottle_ratio = 1. ,
bottle_ratio = 1. ,
block_ratio = 0.5 ,
block_ratio = 0.5 ,
avg_down = avg_down ,
avg_down = avg_down ,
attn_layer = attn_layer ,
attn_kwargs = attn_kwargs ,
stage_type = ' cs3 ' ,
stage_type = ' cs3 ' ,
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
@ -201,7 +219,7 @@ model_cfgs = dict(
bottle_ratio = 0.5 ,
bottle_ratio = 0.5 ,
block_ratio = 1. ,
block_ratio = 1. ,
cross_linear = True ,
cross_linear = True ,
)
) ,
) ,
) ,
cspresnet50w = CspModelCfg (
cspresnet50w = CspModelCfg (
stem = CspStemCfg ( out_chs = ( 32 , 32 , 64 ) , kernel_size = 3 , stride = 4 , pool = ' max ' ) ,
stem = CspStemCfg ( out_chs = ( 32 , 32 , 64 ) , kernel_size = 3 , stride = 4 , pool = ' max ' ) ,
@ -213,7 +231,7 @@ model_cfgs = dict(
bottle_ratio = 0.25 ,
bottle_ratio = 0.25 ,
block_ratio = 0.5 ,
block_ratio = 0.5 ,
cross_linear = True ,
cross_linear = True ,
)
) ,
) ,
) ,
cspresnext50 = CspModelCfg (
cspresnext50 = CspModelCfg (
stem = CspStemCfg ( out_chs = 64 , kernel_size = 7 , stride = 4 , pool = ' max ' ) ,
stem = CspStemCfg ( out_chs = 64 , kernel_size = 7 , stride = 4 , pool = ' max ' ) ,
@ -226,7 +244,7 @@ model_cfgs = dict(
bottle_ratio = 1. ,
bottle_ratio = 1. ,
block_ratio = 0.5 ,
block_ratio = 0.5 ,
cross_linear = True ,
cross_linear = True ,
)
) ,
) ,
) ,
cspdarknet53 = CspModelCfg (
cspdarknet53 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -240,7 +258,6 @@ model_cfgs = dict(
down_growth = True ,
down_growth = True ,
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
darknet17 = CspModelCfg (
darknet17 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -253,7 +270,6 @@ model_cfgs = dict(
stage_type = ' dark ' ,
stage_type = ' dark ' ,
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
darknet21 = CspModelCfg (
darknet21 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -267,7 +283,6 @@ model_cfgs = dict(
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
sedarknet21 = CspModelCfg (
sedarknet21 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -282,7 +297,6 @@ model_cfgs = dict(
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
darknet53 = CspModelCfg (
darknet53 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -295,7 +309,6 @@ model_cfgs = dict(
stage_type = ' dark ' ,
stage_type = ' dark ' ,
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
darknetaa53 = CspModelCfg (
darknetaa53 = CspModelCfg (
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = 32 , kernel_size = 3 , stride = 1 , pool = ' ' ) ,
@ -309,7 +322,6 @@ model_cfgs = dict(
stage_type = ' dark ' ,
stage_type = ' dark ' ,
block_type = ' dark ' ,
block_type = ' dark ' ,
) ,
) ,
act_layer = ' leaky_relu ' ,
) ,
) ,
cs3darknet_s = _cs3darknet_cfg ( width_multiplier = 0.5 , depth_multiplier = 0.5 ) ,
cs3darknet_s = _cs3darknet_cfg ( width_multiplier = 0.5 , depth_multiplier = 0.5 ) ,
@ -322,6 +334,8 @@ model_cfgs = dict(
cs3darknet_focus_l = _cs3darknet_cfg ( focus = True ) ,
cs3darknet_focus_l = _cs3darknet_cfg ( focus = True ) ,
cs3darknet_focus_x = _cs3darknet_cfg ( width_multiplier = 1.25 , depth_multiplier = 1.33 , focus = True ) ,
cs3darknet_focus_x = _cs3darknet_cfg ( width_multiplier = 1.25 , depth_multiplier = 1.33 , focus = True ) ,
cs3sedarknet_l = _cs3darknet_cfg ( attn_layer = ' se ' , attn_kwargs = dict ( rd_ratio = .25 ) ) ,
cs3sedarknet_xdw = CspModelCfg (
cs3sedarknet_xdw = CspModelCfg (
stem = CspStemCfg ( out_chs = ( 32 , 64 ) , kernel_size = 3 , stride = 2 , pool = ' ' ) ,
stem = CspStemCfg ( out_chs = ( 32 , 64 ) , kernel_size = 3 , stride = 2 , pool = ' ' ) ,
stages = CspStagesCfg (
stages = CspStagesCfg (
@ -333,6 +347,7 @@ model_cfgs = dict(
block_ratio = 0.5 ,
block_ratio = 0.5 ,
attn_layer = ' se ' ,
attn_layer = ' se ' ,
) ,
) ,
act_layer = ' silu ' ,
) ,
) ,
)
)
@ -359,14 +374,16 @@ class BottleneckBlock(nn.Module):
super ( BottleneckBlock , self ) . __init__ ( )
super ( BottleneckBlock , self ) . __init__ ( )
mid_chs = int ( round ( out_chs * bottle_ratio ) )
mid_chs = int ( round ( out_chs * bottle_ratio ) )
ckwargs = dict ( act_layer = act_layer , norm_layer = norm_layer )
ckwargs = dict ( act_layer = act_layer , norm_layer = norm_layer )
attn_last = attn_layer is not None and attn_last
attn_first = attn_layer is not None and not attn_last
self . conv1 = ConvNormAct ( in_chs , mid_chs , kernel_size = 1 , * * ckwargs )
self . conv1 = ConvNormAct ( in_chs , mid_chs , kernel_size = 1 , * * ckwargs )
self . conv2 = ConvNormActAa (
self . conv2 = ConvNormActAa (
mid_chs , mid_chs , kernel_size = 3 , dilation = dilation , groups = groups ,
mid_chs , mid_chs , kernel_size = 3 , dilation = dilation , groups = groups ,
aa_layer = aa_layer , drop_layer = drop_block , * * ckwargs )
aa_layer = aa_layer , drop_layer = drop_block , * * ckwargs )
self . attn2 = create_attn( attn_layer , channels = mid_chs ) if not attn_last else None
self . attn2 = attn_layer( mid_chs , act_layer = act_layer ) if attn_first else nn . Identity ( )
self . conv3 = ConvNormAct ( mid_chs , out_chs , kernel_size = 1 , apply_act = False , * * ckwargs )
self . conv3 = ConvNormAct ( mid_chs , out_chs , kernel_size = 1 , apply_act = False , * * ckwargs )
self . attn3 = create_attn( attn_layer , channels = out_chs ) if attn_last else None
self . attn3 = attn_layer( out_chs , act_layer = act_layer ) if attn_last else nn . Identity ( )
self . drop_path = DropPath ( drop_path ) if drop_path else nn . Identity ( )
self . drop_path = DropPath ( drop_path ) if drop_path else nn . Identity ( )
self . act3 = create_act_layer ( act_layer )
self . act3 = create_act_layer ( act_layer )
@ -377,10 +394,8 @@ class BottleneckBlock(nn.Module):
shortcut = x
shortcut = x
x = self . conv1 ( x )
x = self . conv1 ( x )
x = self . conv2 ( x )
x = self . conv2 ( x )
if self . attn2 is not None :
x = self . attn2 ( x )
x = self . attn2 ( x )
x = self . conv3 ( x )
x = self . conv3 ( x )
if self . attn3 is not None :
x = self . attn3 ( x )
x = self . attn3 ( x )
x = self . drop_path ( x ) + shortcut
x = self . drop_path ( x ) + shortcut
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl
@ -410,11 +425,12 @@ class DarkBlock(nn.Module):
super ( DarkBlock , self ) . __init__ ( )
super ( DarkBlock , self ) . __init__ ( )
mid_chs = int ( round ( out_chs * bottle_ratio ) )
mid_chs = int ( round ( out_chs * bottle_ratio ) )
ckwargs = dict ( act_layer = act_layer , norm_layer = norm_layer )
ckwargs = dict ( act_layer = act_layer , norm_layer = norm_layer )
self . conv1 = ConvNormAct ( in_chs , mid_chs , kernel_size = 1 , * * ckwargs )
self . conv1 = ConvNormAct ( in_chs , mid_chs , kernel_size = 1 , * * ckwargs )
self . attn = attn_layer ( mid_chs , act_layer = act_layer ) if attn_layer is not None else nn . Identity ( )
self . conv2 = ConvNormActAa (
self . conv2 = ConvNormActAa (
mid_chs , out_chs , kernel_size = 3 , dilation = dilation , groups = groups ,
mid_chs , out_chs , kernel_size = 3 , dilation = dilation , groups = groups ,
aa_layer = aa_layer , drop_layer = drop_block , * * ckwargs )
aa_layer = aa_layer , drop_layer = drop_block , * * ckwargs )
self . attn = create_attn ( attn_layer , channels = out_chs , act_layer = act_layer )
self . drop_path = DropPath ( drop_path ) if drop_path else nn . Identity ( )
self . drop_path = DropPath ( drop_path ) if drop_path else nn . Identity ( )
def zero_init_last ( self ) :
def zero_init_last ( self ) :
@ -423,9 +439,8 @@ class DarkBlock(nn.Module):
def forward ( self , x ) :
def forward ( self , x ) :
shortcut = x
shortcut = x
x = self . conv1 ( x )
x = self . conv1 ( x )
x = self . conv2 ( x )
if self . attn is not None :
x = self . attn ( x )
x = self . attn ( x )
x = self . conv2 ( x )
x = self . drop_path ( x ) + shortcut
x = self . drop_path ( x ) + shortcut
return x
return x
@ -688,7 +703,8 @@ def create_csp_stem(
return stem , feature_info
return stem , feature_info
def _get_stage_fn ( stage_type : str , stage_args ) :
def _get_stage_fn ( stage_args ) :
stage_type = stage_args . pop ( ' stage_type ' )
assert stage_type in ( ' dark ' , ' csp ' , ' cs3 ' )
assert stage_type in ( ' dark ' , ' csp ' , ' cs3 ' )
if stage_type == ' dark ' :
if stage_type == ' dark ' :
stage_args . pop ( ' expand_ratio ' , None )
stage_args . pop ( ' expand_ratio ' , None )
@ -702,14 +718,25 @@ def _get_stage_fn(stage_type: str, stage_args):
return stage_fn , stage_args
return stage_fn , stage_args
# _get_block_fn: choose the block class for a stage and hand back the (remaining) stage args.
# Returns (DarkBlock, stage_args) when the block type is 'dark', otherwise
# (BottleneckBlock, stage_args); the assert restricts the type to 'dark' or 'bottle'.
#
# NOTE(review): two revisions of this function are interleaved line by line below.
#   * The first `def` takes the type as an explicit first parameter, which is named
#     `stage_type` even though it is compared against block-type values.
#   * The second `def` takes only `stage_args` and pops 'block_type' from it, so the
#     dict is mutated in place before being returned.
# NOTE(review): the string literals are shown padded with spaces and the body has no
# indentation -- presumably an extraction artifact; confirm against the real file.
def _get_block_fn ( stage_type : str , stage_args ) :
def _get_block_fn ( stage_args ) :
assert stage_type in ( ' dark ' , ' bottle ' )
block_type = stage_args . pop ( ' block_type ' )
if stage_type == ' dark ' :
assert block_type in ( ' dark ' , ' bottle ' )
if block_type == ' dark ' :
return DarkBlock , stage_args
return DarkBlock , stage_args
else :
else :
return BottleneckBlock , stage_args
return BottleneckBlock , stage_args
def _get_attn_fn ( stage_args ) :
attn_layer = stage_args . pop ( ' attn_layer ' )
attn_kwargs = stage_args . pop ( ' attn_kwargs ' , None ) or { }
if attn_layer is not None :
attn_layer = get_attn ( attn_layer )
if attn_kwargs :
attn_layer = partial ( attn_layer , * * attn_kwargs )
return attn_layer , stage_args
def create_csp_stages (
def create_csp_stages (
cfg : CspModelCfg ,
cfg : CspModelCfg ,
drop_path_rate : float ,
drop_path_rate : float ,
@ -734,8 +761,9 @@ def create_csp_stages(
feature_info = [ ]
feature_info = [ ]
stages = [ ]
stages = [ ]
for stage_idx , stage_args in enumerate ( stage_args ) :
for stage_idx , stage_args in enumerate ( stage_args ) :
stage_fn , stage_args = _get_stage_fn ( stage_args . pop ( ' stage_type ' ) , stage_args )
stage_fn , stage_args = _get_stage_fn ( stage_args )
block_fn , stage_args = _get_block_fn ( stage_args . pop ( ' block_type ' ) , stage_args )
block_fn , stage_args = _get_block_fn ( stage_args )
attn_fn , stage_args = _get_attn_fn ( stage_args )
stride = stage_args . pop ( ' stride ' )
stride = stage_args . pop ( ' stride ' )
if stride != 1 and prev_feat :
if stride != 1 and prev_feat :
feature_info . append ( prev_feat )
feature_info . append ( prev_feat )
@ -752,6 +780,7 @@ def create_csp_stages(
first_dilation = first_dilation ,
first_dilation = first_dilation ,
dilation = dilation ,
dilation = dilation ,
block_fn = block_fn ,
block_fn = block_fn ,
attn_layer = attn_fn , # will be passed through stage as block_kwargs
* * block_kwargs ,
* * block_kwargs ,
) ]
) ]
prev_chs = stage_args [ ' out_chs ' ]
prev_chs = stage_args [ ' out_chs ' ]
@ -968,6 +997,11 @@ def cs3darknet_focus_x(pretrained=False, **kwargs):
return _create_cspnet ( ' cs3darknet_focus_x ' , pretrained = pretrained , * * kwargs )
return _create_cspnet ( ' cs3darknet_focus_x ' , pretrained = pretrained , * * kwargs )
@register_model
def cs3sedarknet_l(pretrained=False, **kwargs):
    """Create the ``cs3sedarknet_l`` model.

    Args:
        pretrained: Forwarded unchanged to ``_create_cspnet``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``_create_cspnet``.

    Returns:
        Whatever ``_create_cspnet`` returns for the ``'cs3sedarknet_l'``
        variant name.
    """
    # The variant name must match the 'cs3sedarknet_l' config key exactly
    # (no surrounding whitespace).
    return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
# cs3sedarknet_xdw: registered model entry point. Forwards `pretrained` and any extra
# keyword arguments to _create_cspnet together with the variant name.
# NOTE(review): every line of this definition appears twice (unchanged context shown for
# both sides of the diff), the body has no indentation, and the variant-name literal is
# shown padded with spaces -- presumably extraction artifacts; confirm against the real file.
@register_model
@register_model
def cs3sedarknet_xdw ( pretrained = False , * * kwargs ) :
def cs3sedarknet_xdw ( pretrained = False , * * kwargs ) :
return _create_cspnet ( ' cs3sedarknet_xdw ' , pretrained = pretrained , * * kwargs )
return _create_cspnet ( ' cs3sedarknet_xdw ' , pretrained = pretrained , * * kwargs )