@@ -16,12 +16,11 @@ from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
-    create_conv2d, make_divisible
+    create_conv2d, get_act_layer, make_divisible, to_ntuple
 from .registry import register_model
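
Note: both new imports are existing helpers in .layers rather than new code. get_act_layer resolves an activation given by name ('gelu') to its nn.Module class, which is what lets the constructors below carry plain strings instead of class references, and to_ntuple broadcasts a scalar to a fixed-length tuple (used further down for per-stage kernel sizes). A minimal sketch of the name-based lookup, assuming a timm version that exports these helpers:

    import torch
    from timm.models.layers import get_act_layer

    act_layer = get_act_layer('gelu')   # name -> activation class (nn.GELU here)
    act = act_layer()                   # instantiate like any other nn.Module
    print(act(torch.zeros(2)))          # tensor([0., 0.]) since GELU(0) == 0
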
@@ -40,14 +39,13 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = dict(
-    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
-    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
-    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
-    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
-
     # timm specific variants
-    convnext_atto=_cfg(url=''),
-    convnext_atto_ols=_cfg(url=''),
+    convnext_atto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_atto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
     convnext_femto=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
@@ -70,16 +68,34 @@ default_cfgs = dict(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    convnext_tiny=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_small=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_base=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_large=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
     convnext_tiny_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_small_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_base_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_large_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_xlarge_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
     convnext_tiny_384_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
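
Note: the recurring test_input_size=(3, 288, 288) / test_crop_pct additions do not change the models or weights; they record a recommended higher evaluation resolution in each pretrained config, which evaluation scripts can pick up instead of the 224px training resolution. A sketch of what one entry resolves to, assuming the usual _cfg defaults at the top of this file:

    cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0)

    cfg['input_size']       # (3, 224, 224), the resolution the weights were trained at
    cfg['test_input_size']  # (3, 288, 288), the suggested eval resolution
    cfg['test_crop_pct']    # 1.0, the crop fraction to pair with it
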
@@ -121,37 +137,39 @@ class ConvNeXtBlock(nn.Module):
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
 
     Args:
-        dim (int): Number of input channels.
+        in_chs (int): Number of input channels.
         drop_path (float): Stochastic depth rate. Default: 0.0
         ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
     def __init__(
             self,
-            dim,
-            dim_out=None,
+            in_chs,
+            out_chs=None,
+            kernel_size=7,
             stride=1,
             dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
             ls_init_value=1e-6,
+            act_layer='gelu',
             norm_layer=None,
-            act_layer=nn.GELU,
             drop_path=0.,
     ):
         super().__init__()
-        dim_out = dim_out or dim
+        out_chs = out_chs or in_chs
+        act_layer = get_act_layer(act_layer)
         if not norm_layer:
            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
 
         self.conv_dw = create_conv2d(
-            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
-        self.norm = norm_layer(dim_out)
-        self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
+            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(out_chs)
+        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
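
Note: with in_chs/out_chs replacing dim/dim_out, the depthwise kernel exposed as kernel_size, and act_layer accepted as a string, a block is now fully configurable per call site. A quick smoke test of the updated signature (values are illustrative, and the block is assumed imported from this module):

    import torch

    blk = ConvNeXtBlock(in_chs=64, kernel_size=5, act_layer='gelu')
    x = torch.randn(2, 64, 56, 56)  # NCHW input
    y = blk(x)                      # out_chs defaults to in_chs; stride=1 keeps HxW
    print(y.shape)                  # torch.Size([2, 64, 56, 56])
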
@@ -178,6 +196,7 @@ class ConvNeXtStage(nn.Module):
             self,
             in_chs,
             out_chs,
+            kernel_size=7,
             stride=2,
             depth=2,
             dilation=(1, 1),
@@ -185,6 +204,7 @@ class ConvNeXtStage(nn.Module):
             ls_init_value=1.0,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             norm_layer_cl=None
     ):
@@ -208,13 +228,15 @@
         stage_blocks = []
         for i in range(depth):
             stage_blocks.append(ConvNeXtBlock(
-                dim=in_chs,
-                dim_out=out_chs,
+                in_chs=in_chs,
+                out_chs=out_chs,
+                kernel_size=kernel_size,
                 dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
+                act_layer=act_layer,
                 norm_layer=norm_layer if conv_mlp else norm_layer_cl
             ))
             in_chs = out_chs
@@ -252,6 +274,7 @@ class ConvNeXt(nn.Module):
             output_stride=32,
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 768),
+            kernel_sizes=7,
             ls_init_value=1e-6,
             stem_type='patch',
             patch_size=4,
@@ -259,12 +282,14 @@
             head_norm_first=False,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             drop_rate=0.,
             drop_path_rate=0.,
     ):
         super().__init__()
         assert output_stride in (8, 16, 32)
+        kernel_sizes = to_ntuple(4)(kernel_sizes)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
             norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
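
Note: the to_ntuple(4)(kernel_sizes) line is what lets callers pass either a single int or one kernel size per stage. A rough stand-in for the helper's behavior (the real implementation lives in .layers; to_ntuple_sketch is a hypothetical name for illustration):

    from itertools import repeat

    def to_ntuple_sketch(n):
        def parse(x):
            if isinstance(x, (list, tuple)):
                return tuple(x)         # explicit per-stage sizes pass through
            return tuple(repeat(x, n))  # a scalar broadcasts to an n-tuple
        return parse

    to_ntuple_sketch(4)(7)             # (7, 7, 7, 7)
    to_ntuple_sketch(4)((3, 5, 7, 7))  # (3, 5, 7, 7)
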
@@ -312,6 +337,7 @@
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
+                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
@@ -319,6 +345,7 @@
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
+                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl
            ))
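
Note: taken together, kernel_sizes now flows from ConvNeXt.__init__ through each ConvNeXtStage into every ConvNeXtBlock, and act_layer rides along as a string the whole way. A hedged end-to-end construction (depths/dims match the atto-sized spec above but are illustrative; the remaining constructor defaults are assumed to match this file):

    import torch

    model = ConvNeXt(
        depths=(2, 2, 6, 2),
        dims=(40, 80, 160, 320),
        kernel_sizes=(3, 5, 7, 9),  # one depthwise kernel size per stage
        conv_mlp=True,
        act_layer='gelu',
    )
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000]) with the default classifier head
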