From 2544d3b80fdbb7978cba8dbbaf5ae18fcc54efd5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 5 Aug 2022 17:05:50 -0700
Subject: [PATCH] ConvNeXt pico, femto, and nano/pico/femto ols (overlapping
 stem) weights and model defs

---
 timm/models/convnext.py | 104 ++++++++++++++++++++++++++++++++--------
 1 file changed, 83 insertions(+), 21 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 4b22c929..0c324719 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -4,6 +4,8 @@
 Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
 Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below

+Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
+
 Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
 """
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -18,7 +20,8 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
+from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
+    create_conv2d, make_divisible
 from .registry import register_model


@@ -43,11 +46,26 @@ default_cfgs = dict(
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

     # timm specific variants
+    convnext_atto=_cfg(url=''),
+    convnext_atto_ols=_cfg(url=''),
+    convnext_femto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_femto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_pico_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_nano=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_nano_hnf=_cfg(url=''),
-    convnext_nano_ols=_cfg(url=''),
+    convnext_nano_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_tiny_hnf=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
@@ -236,8 +254,7 @@ class ConvNeXt(nn.Module):
             dims=(96, 192, 384, 768),
             ls_init_value=1e-6,
             stem_type='patch',
-            stem_kernel_size=4,
-            stem_stride=4,
+            patch_size=4,
             head_init_scale=1.,
             head_norm_first=False,
             conv_mlp=False,
@@ -260,21 +277,22 @@ class ConvNeXt(nn.Module):
         self.drop_rate = drop_rate
         self.feature_info = []

-        assert stem_type in ('patch', 'overlap')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
         if stem_type == 'patch':
-            assert stem_kernel_size == stem_stride
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                 norm_layer(dims[0])
             )
+            stem_stride = patch_size
         else:
+            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
             self.stem = nn.Sequential(
-                nn.Conv2d(
-                    in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
-                    padding=stem_kernel_size // 2, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
             )
+            stem_stride = 4

         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@@ -415,20 +433,65 @@ def _create_convnext(variant, pretrained=False, **kwargs):


 @register_model
-def convnext_nano(pretrained=False, **kwargs):
-    # timm nano variant with standard stem and head
+def convnext_atto(pretrained=False, **kwargs):
+    # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_atto_ols(pretrained=False, **kwargs):
+    # timm atto variant with overlapping 3x3 conv stem, current param count 3.7M
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_femto(pretrained=False, **kwargs):
+    # timm femto variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_femto_ols(pretrained=False, **kwargs):
+    # timm femto variant with overlapping 3x3 conv stem
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_pico(pretrained=False, **kwargs):
+    # timm pico variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
     return model


 @register_model
-def convnext_nano_hnf(pretrained=False, **kwargs):
-    # experimental nano variant with normalization before pooling in head (head norm first)
+def convnext_pico_ols(pretrained=False, **kwargs):
+    # timm pico variant with overlapping 3x3 conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_nano(pretrained=False, **kwargs):
+    # timm nano variant with standard stem and head
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
     return model


@@ -436,8 +499,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 def convnext_nano_ols(pretrained=False, **kwargs):
     # experimental nano variant with overlapping conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
-        stem_type='overlap', stem_kernel_size=9, **kwargs)
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
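
Reviewer note: for anyone skimming the stem change, here is a minimal standalone sketch of the three stem forms after this patch. Plain PyTorch only; build_stem, _make_divisible, and the GroupNorm normalization are illustrative stand-ins, not timm's actual helpers (timm uses LayerNorm2d and layers.make_divisible).

    import torch
    import torch.nn as nn

    def _make_divisible(v, divisor=8):
        # simplified stand-in for timm's make_divisible (ignores round_limit/min_value)
        return max(divisor, int(v + divisor / 2) // divisor * divisor)

    def build_stem(stem_type, in_chans=3, dim=40, patch_size=4, conv_bias=True):
        if stem_type == 'patch':
            # single non-overlapping conv, kernel == stride == patch_size (default 4)
            return nn.Sequential(
                nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                nn.GroupNorm(1, dim),  # stand-in for timm's LayerNorm2d
            )
        # 'overlap' / 'overlap_tiered': two overlapping 3x3 stride-2 convs;
        # the tiered form runs the first conv at roughly half width
        mid_chs = _make_divisible(dim // 2) if 'tiered' in stem_type else dim
        return nn.Sequential(
            nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
            nn.Conv2d(mid_chs, dim, kernel_size=3, stride=2, padding=1, bias=conv_bias),
            nn.GroupNorm(1, dim),  # stand-in for timm's LayerNorm2d
        )

    x = torch.randn(2, 3, 224, 224)
    for st in ('patch', 'overlap', 'overlap_tiered'):
        y = build_stem(st)(x)
        print(st, tuple(y.shape))  # all three keep overall stride 4: (2, 40, 56, 56)

All stem types reduce resolution by 4x, which is why the constructor can hardcode stem_stride = 4 in the overlap branch.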
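And a quick smoke test of the new registry names, assuming a timm checkout with this patch applied. The atto/atto_ols cfgs above still have empty weight URLs, so everything here is created with pretrained=False:

    import torch
    import timm

    # registry names touched by this patch
    names = [
        'convnext_atto', 'convnext_atto_ols',
        'convnext_femto', 'convnext_femto_ols',
        'convnext_pico', 'convnext_pico_ols',
        'convnext_nano', 'convnext_nano_ols',
    ]
    x = torch.randn(1, 3, 224, 224)
    for name in names:
        model = timm.create_model(name, pretrained=False)
        n_params = sum(p.numel() for p in model.parameters())
        print(f'{name}: {n_params / 1e6:.1f}M params, output {tuple(model(x).shape)}')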