Compare commits

...

3 Commits

@ -28,6 +28,11 @@ For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`
* Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808)
* Reference impl: https://github.com/facebookresearch/ConvNeXt-V2 (NOTE: weights currently CC-BY-NC)
### Dec 23, 2022 🎄☃
* Add FlexiViT models and weights from https://github.com/google-research/big_vision (check out paper at https://arxiv.org/abs/2212.08013)
* NOTE currently resizing is static on model creation, on-the-fly dynamic / train patch size sampling is a WIP
@ -396,6 +401,7 @@ A full version of the list below with source links can be found in the [document
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConvNeXt-V2 - http://arxiv.org/abs/2301.00808
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT - https://arxiv.org/abs/2012.12877
@ -418,6 +424,7 @@ A full version of the list below with source links can be found in the [document
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* TinyNet - https://arxiv.org/abs/2010.14819
* EVA - https://arxiv.org/abs/2211.07636
* FlexiViT - https://arxiv.org/abs/2212.08013
* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050

@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*']
'swin*giant*', 'convnextv2_huge*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
@ -129,7 +129,7 @@ def test_model_backward(model_name, batch_size):
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""
@ -191,7 +191,7 @@ def test_model_default_cfgs(model_name, batch_size):
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
"""Run a single forward pass with each model"""
@ -304,7 +304,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""

@ -26,7 +26,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm

@ -0,0 +1,39 @@
""" Global Response Normalization Module
Based on the GRN layer presented in
`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
This implementation
* works for both NCHW and NHWC tensor layouts
* uses affine param names matching existing torch norm layers
* slightly improves eager mode performance via fused addcmul
Hacked together by / Copyright 2023 Ross Wightman
"""
import torch
from torch import nn as nn
class GlobalResponseNorm(nn.Module):
""" Global Response Normalization layer
"""
def __init__(self, dim, eps=1e-6, channels_last=True):
super().__init__()
self.eps = eps
if channels_last:
self.spatial_dim = (1, 2)
self.channel_dim = -1
self.wb_shape = (1, 1, 1, -1)
else:
self.spatial_dim = (2, 3)
self.channel_dim = 1
self.wb_shape = (1, -1, 1, 1)
self.weight = nn.Parameter(torch.zeros(dim))
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, x):
x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)

@ -2,25 +2,38 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial
from torch import nn as nn
from .grn import GlobalResponseNorm
from .helpers import to_2tuple
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
@ -36,18 +49,29 @@ class GluMlp(nn.Module):
""" MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.Sigmoid,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert hidden_features % 2 == 0
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.chunk_dim = 1 if use_conv else -1
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def init_weights(self):
@ -58,7 +82,7 @@ class GluMlp(nn.Module):
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)
x, gates = x.chunk(2, dim=self.chunk_dim)
x = x * self.act(gates)
x = self.drop1(x)
x = self.fc2(x)
@ -70,8 +94,15 @@ class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
gate_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -104,8 +135,15 @@ class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
norm_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
norm_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -124,3 +162,40 @@ class ConvMlp(nn.Module):
x = self.drop(x)
x = self.fc2(x)
return x
class GlobalResponseNormMlp(nn.Module):
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.grn(x)
x = self.fc2(x)
x = self.drop2(x)
return x

@ -1,16 +1,42 @@
""" ConvNeXt
Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Papers:
* `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
@Article{liu2022convnet,
author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
title = {A ConvNet for the 2020s},
journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
}
* `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
@article{Woo2023ConvNeXtV2,
title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
year={2023},
journal={arXiv preprint arXiv:2301.00808},
}
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
Original code and weights from:
* https://github.com/facebookresearch/ConvNeXt, original copyright below
* https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# ConvNeXt
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
# ConvNeXt-V2
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
from collections import OrderedDict
from functools import partial
@ -18,8 +44,8 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
@ -54,6 +80,7 @@ class ConvNeXtBlock(nn.Module):
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
use_grn=False,
ls_init_value=1e-6,
act_layer='gelu',
norm_layer=None,
@ -64,14 +91,13 @@ class ConvNeXtBlock(nn.Module):
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = ConvMlp if conv_mlp else Mlp
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
self.use_conv_mlp = conv_mlp
self.conv_dw = create_conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
@ -106,6 +132,7 @@ class ConvNeXtStage(nn.Module):
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_layer_cl=None
@ -138,6 +165,7 @@ class ConvNeXtStage(nn.Module):
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
))
@ -184,6 +212,7 @@ class ConvNeXt(nn.Module):
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer='gelu',
norm_layer=None,
drop_rate=0.,
@ -247,6 +276,7 @@ class ConvNeXt(nn.Module):
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
@ -259,10 +289,11 @@ class ConvNeXt(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.head_norm_first = head_norm_first
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@ -293,7 +324,14 @@ class ConvNeXt(nn.Module):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
if num_classes == 0:
self.head.norm = nn.Identity()
self.head.fc = nn.Identity()
else:
if not self.head_norm_first:
norm_layer = type(self.stem[-1]) # obtain type from stem norm
self.head.norm = norm_layer(self.num_features)
self.head.fc = nn.Linear(self.num_features, num_classes)
def forward_features(self, x):
x = self.stem(x)
@ -342,6 +380,10 @@ def checkpoint_filter_fn(state_dict, model):
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
if 'grn' in k:
k = k.replace('grn.beta', 'mlp.grn.bias')
k = k.replace('grn.gamma', 'mlp.grn.weight')
v = v.reshape(v.shape[-1])
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
@ -372,6 +414,20 @@ def _cfg(url='', **kwargs):
}
def _cfgv2(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
**kwargs
}
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_atto.d2_in1k': _cfg(
@ -499,6 +555,115 @@ default_cfgs = generate_default_cfgs({
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_base.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_large.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_atto.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_femto.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_pico.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_nano.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_tiny.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_base.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_large.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_huge.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_small.untrained': _cfg(),
})
@ -623,3 +788,75 @@ def convnext_xxlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_femto(pretrained=False, **kwargs):
# timm femto variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_pico(pretrained=False, **kwargs):
# timm pico variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_tiny(pretrained=False, **kwargs):
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
return model
Loading…
Cancel
Save