@@ -4,6 +4,8 @@ Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
 
 Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
 
+Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
+
 Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
 """
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -18,7 +20,8 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
+from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, \
+    create_conv2d, make_divisible
 from .registry import register_model
 
 
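The new `make_divisible` import feeds the tiered stem introduced further down, which halves `dims[0]` for the first stem conv and rounds the result to a channel-friendly multiple. A minimal sketch of that rounding rule, assuming the divisor-of-8 default used by timm's layer helpers:

```python
# Sketch of make_divisible-style rounding (divisor=8 default assumed).
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:  # never round down by more than ~10%
        new_v += divisor
    return new_v

# Tiered stem width for convnext_atto_ols below: dims[0] // 2 = 20 -> 24 channels.
assert make_divisible(40 // 2) == 24
```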
@@ -43,11 +46,26 @@ default_cfgs = dict(
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
 
     # timm specific variants
+    convnext_atto=_cfg(url=''),
+    convnext_atto_ols=_cfg(url=''),
+    convnext_femto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_femto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_nano=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_nano_hnf=_cfg(url=''),
-    convnext_nano_ols=_cfg(url=''),
+    convnext_nano_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_tiny_hnf=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
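These entries follow the usual timm train/test split: `input_size` stays at the 224x224 training resolution (from the `_cfg` defaults), while `test_input_size` and `test_crop_pct` advertise the resolution at which the released weights validate best. A hedged sketch of how a caller could resolve the eval settings from one of these cfg dicts (`resolve_test_cfg` is illustrative, not a timm API):

```python
# Illustrative helper (not timm API): pick eval resolution/crop from a cfg dict.
def resolve_test_cfg(cfg, use_test_size=True):
    if use_test_size:
        return cfg.get('test_input_size', cfg.get('input_size')), \
               cfg.get('test_crop_pct', cfg.get('crop_pct'))
    return cfg.get('input_size'), cfg.get('crop_pct')

nano_cfg = dict(  # mirrors the convnext_nano entry, assuming _cfg's 224x224 default
    input_size=(3, 224, 224), crop_pct=0.95,
    test_input_size=(3, 288, 288), test_crop_pct=1.0)
print(resolve_test_cfg(nano_cfg))  # ((3, 288, 288), 1.0)
```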
@@ -236,8 +254,7 @@ class ConvNeXt(nn.Module):
             dims=(96, 192, 384, 768),
             ls_init_value=1e-6,
             stem_type='patch',
-            stem_kernel_size=4,
-            stem_stride=4,
+            patch_size=4,
             head_init_scale=1.,
             head_norm_first=False,
             conv_mlp=False,
@@ -260,21 +277,22 @@
         self.drop_rate = drop_rate
         self.feature_info = []
 
-        assert stem_type in ('patch', 'overlap')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
         if stem_type == 'patch':
-            assert stem_kernel_size == stem_stride
+            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                 norm_layer(dims[0])
             )
+            stem_stride = patch_size
         else:
+            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
             self.stem = nn.Sequential(
-                nn.Conv2d(
-                    in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
-                    padding=stem_kernel_size // 2, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
             )
+            stem_stride = 4
 
         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
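The reworked stem keeps a total stride of 4 on both paths: the patch path is a single `patch_size`-stride conv, while the overlap path stacks two 3x3/stride-2 convs, with the 'tiered' flavor narrowing the first conv to `make_divisible(dims[0] // 2)`. A standalone shape check of the overlap_tiered path, assuming the atto width `dims[0] = 40` (so `mid_chs = 24` per the rounding sketch earlier):

```python
import torch
import torch.nn as nn

# Standalone mirror of the overlap_tiered stem path (atto: dims[0]=40, mid_chs=24).
mid_chs, dim0 = 24, 40
stem = nn.Sequential(
    nn.Conv2d(3, mid_chs, kernel_size=3, stride=2, padding=1),     # 224 -> 112
    nn.Conv2d(mid_chs, dim0, kernel_size=3, stride=2, padding=1),  # 112 -> 56
)
x = torch.randn(1, 3, 224, 224)
print(stem(x).shape)  # torch.Size([1, 40, 56, 56]); same stride-4 grid as the patch stem
```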
@@ -415,20 +433,65 @@ def _create_convnext(variant, pretrained=False, **kwargs):
 @register_model
-def convnext_nano(pretrained=False, **kwargs):
-    # timm nano variant with standard stem and head
+def convnext_atto(pretrained=False, **kwargs):
+    # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
     return model
 
 
+@register_model
+def convnext_atto_ols(pretrained=False, **kwargs):
+    # timm atto variant with overlapping 3x3 conv stem, current param count 3.7M
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_femto(pretrained=False, **kwargs):
+    # timm femto variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_femto_ols(pretrained=False, **kwargs):
+    # timm femto variant with overlapping 3x3 conv stem
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_pico(pretrained=False, **kwargs):
+    # timm pico variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
-def convnext_nano_hnf(pretrained=False, **kwargs):
-    # experimental nano variant with normalization before pooling in head (head norm first)
+def convnext_pico_ols(pretrained=False, **kwargs):
+    # timm pico variant with overlapping 3x3 conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
     return model
 
 
+@register_model
+def convnext_nano(pretrained=False, **kwargs):
+    # timm nano variant with standard stem and head
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
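Since every variant goes through `@register_model`, the additions can be exercised through the normal timm factory; a quick smoke test (random init, as the atto cfg above still has an empty weight URL):

```python
import torch
import timm

model = timm.create_model('convnext_atto', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params')  # ~3.7M, per the comment above
```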
@@ -436,8 +499,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 def convnext_nano_ols(pretrained=False, **kwargs):
     # experimental nano variant with overlapping conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
-        stem_type='overlap', stem_kernel_size=9, **kwargs)
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
 
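For scale, this last hunk swaps `convnext_nano_ols` from a single 9x9/stride-4 overlapped stem conv to the shared pair of 3x3/stride-2 convs (non-tiered `overlap`, so `mid_chs == dims[0]`). A rough parameter comparison of the two stem shapes, assuming `dims[0] = 80` as in nano:

```python
import torch.nn as nn

count = lambda m: sum(p.numel() for p in m.parameters())
old_stem = nn.Conv2d(3, 80, kernel_size=9, stride=4, padding=4)  # prior 'overlap' stem shape
new_stem = nn.Sequential(
    nn.Conv2d(3, 80, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(80, 80, kernel_size=3, stride=2, padding=1),
)
print(count(old_stem), count(new_stem))  # 19520 vs 59920; both produce stride-4 features
```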