ConvNeXt pico, femto, and nano/pico/femto _ols (overlapping stem) weights and model defs

pull/1420/head
Ross Wightman 2 years ago
parent 13565aad50
commit 2544d3b80f

@@ -4,6 +4,8 @@ Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
 Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
+Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
 Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
 """
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -18,7 +20,8 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
+from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
+    create_conv2d, make_divisible
 from .registry import register_model
@@ -43,11 +46,26 @@ default_cfgs = dict(
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

     # timm specific variants
+    convnext_atto=_cfg(url=''),
+    convnext_atto_ols=_cfg(url=''),
+    convnext_femto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_femto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_nano=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_nano_hnf=_cfg(url=''),
-    convnext_nano_ols=_cfg(url=''),
+    convnext_nano_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_tiny_hnf=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
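Note: the new cfgs advertise a 288x288 test resolution on top of the 224x224 train size. A minimal usage sketch (not part of this commit; assumes the release assets above are published and a timm install with this change applied):

import timm
import torch

# create one of the new weighted variants and run it at the cfg's test_input_size
model = timm.create_model('convnext_femto', pretrained=True).eval()
x = torch.randn(1, 3, 288, 288)  # stride-32 net, so 288 works as well as 224
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])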
@@ -236,8 +254,7 @@ class ConvNeXt(nn.Module):
             dims=(96, 192, 384, 768),
             ls_init_value=1e-6,
             stem_type='patch',
-            stem_kernel_size=4,
-            stem_stride=4,
+            patch_size=4,
             head_init_scale=1.,
             head_norm_first=False,
             conv_mlp=False,
@@ -260,21 +277,22 @@ class ConvNeXt(nn.Module):
         self.drop_rate = drop_rate
         self.feature_info = []

-        assert stem_type in ('patch', 'overlap')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
         if stem_type == 'patch':
-            assert stem_kernel_size == stem_stride
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                 norm_layer(dims[0])
             )
+            stem_stride = patch_size
         else:
+            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
             self.stem = nn.Sequential(
-                nn.Conv2d(
-                    in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
-                    padding=stem_kernel_size // 2, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
             )
+            stem_stride = 4

         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
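For reference, a standalone sketch (outside the commit) of the new 'overlap_tiered' stem: two overlapping 3x3 stride-2 convs replace the single 4x4 patchify conv, keeping the overall stem stride at 4. make_divisible is re-implemented here in simplified form for illustration only; timm's version adds min_value/round_limit handling.

import torch
import torch.nn as nn

def make_divisible(v, divisor=8):
    # simplified stand-in for timm's make_divisible, illustration only
    return max(divisor, int(v + divisor / 2) // divisor * divisor)

dim0 = 40                              # convnext_atto stage-0 width
mid_chs = make_divisible(dim0 // 2)    # tiered: 20 rounds up to 24
stem = nn.Sequential(
    nn.Conv2d(3, mid_chs, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(mid_chs, dim0, kernel_size=3, stride=2, padding=1),
)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 40, 56, 56]) -> overall stride 4, same as the patch stem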
@@ -415,20 +433,65 @@ def _create_convnext(variant, pretrained=False, **kwargs):
 @register_model
-def convnext_nano(pretrained=False, **kwargs):
-    # timm nano variant with standard stem and head
+def convnext_atto(pretrained=False, **kwargs):
+    # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
     return model
+
+@register_model
+def convnext_atto_ols(pretrained=False, **kwargs):
+    # timm atto variant with overlapping tiered 3x3 conv stem, current param count 3.7M
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
+    return model
+
+@register_model
+def convnext_femto(pretrained=False, **kwargs):
+    # timm femto variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
+    return model
+
+@register_model
+def convnext_femto_ols(pretrained=False, **kwargs):
+    # timm femto variant with overlapping tiered 3x3 conv stem
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
+    return model
+
+@register_model
+def convnext_pico(pretrained=False, **kwargs):
+    # timm pico variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
+    return model
+
 @register_model
-def convnext_nano_hnf(pretrained=False, **kwargs):
-    # experimental nano variant with normalization before pooling in head (head norm first)
+def convnext_pico_ols(pretrained=False, **kwargs):
+    # timm pico variant with overlapping tiered 3x3 conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
     return model
+
+@register_model
+def convnext_nano(pretrained=False, **kwargs):
+    # timm nano variant with standard stem and head
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+    return model
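The NOTE in convnext_atto quotes a rough ~3.7M parameter count; an illustrative check of all the new variants (assumes only the registry names from this commit):

import timm

for name in ('convnext_atto', 'convnext_femto', 'convnext_pico', 'convnext_nano'):
    m = timm.create_model(name)  # random init, no download needed
    print(f'{name}: {sum(p.numel() for p in m.parameters()) / 1e6:.1f}M params')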
@@ -436,8 +499,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 def convnext_nano_ols(pretrained=False, **kwargs):
     # experimental nano variant with overlapping conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
-        stem_type='overlap', stem_kernel_size=9, **kwargs)
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
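To see what actually changed between a standard and an _ols model, the stems can be compared directly (illustrative; the 9x9 stem_kernel_size experiment is dropped by the hunk above in favor of dual 3x3 convs):

import timm

print(timm.create_model('convnext_nano').stem)      # single 4x4 stride-4 patchify conv + norm
print(timm.create_model('convnext_nano_ols').stem)  # two overlapping 3x3 stride-2 convs + norm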
