diff --git a/README.md b/README.md
index c7f4bee7..019fdae2 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,12 @@ Thanks to the following for hardware support:
 And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face.
 
 ## What's New
+
+### Aug 15, 2022
+* ConvNeXt atto weights added
+  * `convnext_atto` - 75.7 @ 224, 77.0 @ 288
+  * `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288
+
 ### Aug 5, 2022
 * More custom ConvNeXt smaller model defs with weights
   * `convnext_femto` - 77.5 @ 224, 78.7 @ 288
diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 0c324719..ba63a453 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -16,12 +16,11 @@
 from functools import partial
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
-    create_conv2d, make_divisible
+    create_conv2d, get_act_layer, make_divisible, to_ntuple
 from .registry import register_model
 
@@ -40,14 +39,13 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = dict(
-    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
-    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
-    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
-    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
-
     # timm specific variants
-    convnext_atto=_cfg(url=''),
-    convnext_atto_ols=_cfg(url=''),
+    convnext_atto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_atto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
     convnext_femto=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
@@ -70,16 +68,34 @@ default_cfgs = dict(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    convnext_tiny=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_small=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_base=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_large=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
     convnext_tiny_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_small_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_base_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_large_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_xlarge_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
     convnext_tiny_384_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
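The `test_input_size` / `test_crop_pct` fields added to the configs above are what produce the "@ 288" eval numbers in the changelog. A minimal sketch of inspecting them, assuming timm >= 0.6.9 with the new atto release assets published:

```python
import timm

# Load the new ConvNeXt atto weights added in this change.
model = timm.create_model('convnext_atto', pretrained=True)
model.eval()

# Each timm model carries its pretrained config; the fields set above
# suggest evaluating at 288x288 with a 0.95 crop in addition to 224.
cfg = model.default_cfg
print(cfg['input_size'])       # (3, 224, 224) - default train/eval size from _cfg
print(cfg['test_input_size'])  # (3, 288, 288)
print(cfg['test_crop_pct'])    # 0.95
```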
@@ -121,37 +137,39 @@ class ConvNeXtBlock(nn.Module):
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
 
     Args:
-        dim (int): Number of input channels.
+        in_chs (int): Number of input channels.
         drop_path (float): Stochastic depth rate. Default: 0.0
         ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
     def __init__(
             self,
-            dim,
-            dim_out=None,
+            in_chs,
+            out_chs=None,
+            kernel_size=7,
             stride=1,
             dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
             ls_init_value=1e-6,
+            act_layer='gelu',
             norm_layer=None,
-            act_layer=nn.GELU,
             drop_path=0.,
     ):
         super().__init__()
-        dim_out = dim_out or dim
+        out_chs = out_chs or in_chs
+        act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
         self.conv_dw = create_conv2d(
-            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
-        self.norm = norm_layer(dim_out)
-        self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
+            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(out_chs)
+        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
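The block above renames `dim`/`dim_out` to `in_chs`/`out_chs` and adds a `kernel_size` argument plus a string-based `act_layer` resolved via `get_act_layer`. A minimal sketch of using the renamed interface directly, assuming the class is imported from `timm.models.convnext` at this version:

```python
import torch
from timm.models.convnext import ConvNeXtBlock

# out_chs defaults to in_chs; kernel_size and the string act_layer are new here.
block = ConvNeXtBlock(in_chs=64, kernel_size=5, act_layer='gelu')

x = torch.randn(2, 64, 56, 56)
y = block(x)  # depthwise conv is padded, so spatial dims are preserved at stride 1
print(y.shape)  # torch.Size([2, 64, 56, 56])
```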
@@ -178,6 +196,7 @@ class ConvNeXtStage(nn.Module):
             self,
             in_chs,
             out_chs,
+            kernel_size=7,
             stride=2,
             depth=2,
             dilation=(1, 1),
@@ -185,6 +204,7 @@ class ConvNeXtStage(nn.Module):
             ls_init_value=1.0,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             norm_layer_cl=None
     ):
@@ -208,13 +228,15 @@ class ConvNeXtStage(nn.Module):
         stage_blocks = []
         for i in range(depth):
             stage_blocks.append(ConvNeXtBlock(
-                dim=in_chs,
-                dim_out=out_chs,
+                in_chs=in_chs,
+                out_chs=out_chs,
+                kernel_size=kernel_size,
                 dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
+                act_layer=act_layer,
                 norm_layer=norm_layer if conv_mlp else norm_layer_cl
             ))
             in_chs = out_chs
@@ -252,6 +274,7 @@ class ConvNeXt(nn.Module):
             output_stride=32,
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 768),
+            kernel_sizes=7,
             ls_init_value=1e-6,
             stem_type='patch',
             patch_size=4,
@@ -259,12 +282,14 @@ class ConvNeXt(nn.Module):
             head_norm_first=False,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             drop_rate=0.,
             drop_path_rate=0.,
     ):
         super().__init__()
         assert output_stride in (8, 16, 32)
+        kernel_sizes = to_ntuple(4)(kernel_sizes)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
             norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -312,6 +337,7 @@ class ConvNeXt(nn.Module):
             stages.append(ConvNeXtStage(
                 prev_chs,
                 out_chs,
+                kernel_size=kernel_sizes[i],
                 stride=stride,
                 dilation=(first_dilation, dilation),
                 depth=depths[i],
@@ -319,6 +345,7 @@ class ConvNeXt(nn.Module):
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 norm_layer_cl=norm_layer_cl
             ))
diff --git a/timm/version.py b/timm/version.py
index 085bc856..6300a709 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.6.8'
+__version__ = '0.6.9'
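With `kernel_sizes` and `act_layer` plumbed through `ConvNeXt` -> `ConvNeXtStage` -> `ConvNeXtBlock`, a single int is broadcast to all four stages by `to_ntuple(4)`, while a 4-tuple sets one kernel size per stage. A hedged sketch of building an untrained model with per-stage kernels; the depths/dims below are illustrative, not a registered timm model def:

```python
from timm.models.convnext import ConvNeXt

model = ConvNeXt(
    depths=(2, 2, 6, 2),        # illustrative small config
    dims=(40, 80, 160, 320),
    kernel_sizes=(3, 5, 7, 9),  # per-stage depthwise kernel; kernel_sizes=7 would broadcast
    act_layer='gelu',           # resolved to nn.GELU by get_act_layer inside each block
)
print(sum(p.numel() for p in model.parameters()))
```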