@@ -109,6 +109,7 @@ class ConvNeXtBlock(nn.Module):
             dim,
             dim_out=None,
             stride=1,
+            dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
@@ -124,7 +125,8 @@ class ConvNeXtBlock(nn.Module):
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
-        self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
+        self.conv_dw = create_conv2d(
+            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
         self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
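
Not part of the patch: a quick standalone check of what the new dilation argument does to the 7x7 depthwise conv. Plain nn.Conv2d is used here as a stand-in for timm's create_conv2d, with the padding worked out by hand (create_conv2d resolves its own padding).

import torch
import torch.nn as nn

x = torch.randn(1, 96, 28, 28)
# dilation=1: the standard ConvNeXt depthwise 7x7, padding 3 keeps the spatial size
dw = nn.Conv2d(96, 96, kernel_size=7, padding=3, groups=96)
print(dw(x).shape)      # torch.Size([1, 96, 28, 28])
# dilation=2: padding grows to dilation * (7 - 1) // 2 = 6, spatial size is still kept
dw_dil = nn.Conv2d(96, 96, kernel_size=7, padding=6, dilation=2, groups=96)
print(dw_dil(x).shape)  # torch.Size([1, 96, 28, 28])
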
@@ -156,6 +158,7 @@ class ConvNeXtStage(nn.Module):
             out_chs,
             stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
             conv_mlp=False,
@@ -166,10 +169,14 @@ class ConvNeXtStage(nn.Module):
         super().__init__()
         self.grad_checkpointing = False
 
-        if in_chs != out_chs or stride > 1:
+        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
+                create_conv2d(
+                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
+                    dilation=dilation[0], padding=pad, bias=conv_bias),
             )
             in_chs = out_chs
         else:
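
Not part of the patch: a minimal sketch of why the downsample branch switches to a kernel of ds_ks=2 with 'same' padding once stride has been folded into dilation. Plain nn.Conv2d stands in for create_conv2d here.

import torch
import torch.nn as nn

x = torch.randn(1, 96, 28, 28)
# normal strided downsample (dilation unchanged): 2x2 conv, stride 2 halves the resolution
down = nn.Conv2d(96, 192, kernel_size=2, stride=2)
print(down(x).shape)      # torch.Size([1, 192, 14, 14])
# stride folded into dilation: stride 1, dilation 2, 'same' padding keeps the resolution
# while the channel projection still happens (the pad='same' case above)
down_dil = nn.Conv2d(96, 192, kernel_size=2, stride=1, dilation=2, padding='same')
print(down_dil(x).shape)  # torch.Size([1, 192, 28, 28])
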
@@ -181,6 +188,7 @@ class ConvNeXtStage(nn.Module):
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -235,7 +243,7 @@ class ConvNeXt(nn.Module):
             drop_path_rate=0.,
     ):
         super().__init__()
-        assert output_stride == 32
+        assert output_stride in (8, 16, 32)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -263,22 +271,27 @@ class ConvNeXt(nn.Module):
                 padding=stem_kernel_size // 2, bias=conv_bias),
             norm_layer(dims[0]),
         )
-        prev_chs = dims[0]
-        curr_stride = stem_stride
 
         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
+        prev_chs = dims[0]
+        curr_stride = stem_stride
+        dilation = 1
         # 4 feature resolution stages, each consisting of multiple residual blocks
         for i in range(4):
             stride = 2 if curr_stride == 2 or i > 0 else 1
-            # FIXME support dilation / output_stride
+            if curr_stride >= output_stride and stride > 1:
+                dilation *= stride
+                stride = 1
             curr_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
             out_chs = dims[i]
             stages.append(ConvNeXtStage(
                 prev_chs,
                 out_chs,
                 stride=stride,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
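
Not part of the patch: a standalone sketch reproducing the stride/dilation schedule the loop above assigns to the four stages, assuming a stride-4 patch stem (stem_stride = 4). The stage_plan helper is hypothetical, written only to illustrate the logic.

def stage_plan(output_stride, stem_stride=4):
    # mirrors the per-stage bookkeeping in ConvNeXt.__init__ above
    curr_stride = stem_stride
    dilation = 1
    plan = []
    for i in range(4):
        stride = 2 if curr_stride == 2 or i > 0 else 1
        if curr_stride >= output_stride and stride > 1:
            dilation *= stride  # trade stride for dilation once the target stride is reached
            stride = 1
        curr_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2
        plan.append((stride, (first_dilation, dilation)))
    return plan

print(stage_plan(32))  # [(1, (1, 1)), (2, (1, 1)), (2, (1, 1)), (2, (1, 1))]
print(stage_plan(16))  # [(1, (1, 1)), (2, (1, 1)), (2, (1, 1)), (1, (1, 2))]
print(stage_plan(8))   # [(1, (1, 1)), (2, (1, 1)), (1, (1, 2)), (1, (2, 4))]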