@ -21,111 +21,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from . helpers import named_apply , build_model_with_cfg , checkpoint_seq
from . layers import trunc_normal_ , SelectAdaptivePool2d , DropPath , ConvMlp , Mlp , LayerNorm2d , LayerNorm , \
create_conv2d , get_act_layer , make_divisible , to_ntuple
from . _pretrained import generate_defaults
from . registry import register_model
__all__ = [ ' ConvNeXt ' ] # model_registry will add each entrypoint fn to this
def _cfg ( url = ' ' , * * kwargs ) :
return {
' url ' : url ,
' num_classes ' : 1000 , ' input_size ' : ( 3 , 224 , 224 ) , ' pool_size ' : ( 7 , 7 ) ,
' crop_pct ' : 0.875 , ' interpolation ' : ' bicubic ' ,
' mean ' : IMAGENET_DEFAULT_MEAN , ' std ' : IMAGENET_DEFAULT_STD ,
' first_conv ' : ' stem.0 ' , ' classifier ' : ' head.fc ' ,
* * kwargs
}
default_cfgs = dict (
# timm specific variants
convnext_atto = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
convnext_atto_ols = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
convnext_femto = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
convnext_femto_ols = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
convnext_pico = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
convnext_pico_ols = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_nano = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_nano_ols = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_tiny_hnf = _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_tiny = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_small = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_base = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_large = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_tiny_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_small_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_base_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_large_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_xlarge_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
convnext_tiny_384_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
convnext_small_384_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
convnext_base_384_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
convnext_large_384_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
convnext_xlarge_384_in22ft1k = _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 ) ,
convnext_tiny_in22k = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth " , num_classes = 21841 ) ,
convnext_small_in22k = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth " , num_classes = 21841 ) ,
convnext_base_in22k = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth " , num_classes = 21841 ) ,
convnext_large_in22k = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth " , num_classes = 21841 ) ,
convnext_xlarge_in22k = _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth " , num_classes = 21841 ) ,
)
class ConvNeXtBlock ( nn . Module ) :
""" ConvNeXt Block
There are two equivalent implementations :
@ -459,6 +361,107 @@ def _create_convnext(variant, pretrained=False, **kwargs):
return model
def _cfg ( url = ' ' , * * kwargs ) :
return {
' url ' : url ,
' num_classes ' : 1000 , ' input_size ' : ( 3 , 224 , 224 ) , ' pool_size ' : ( 7 , 7 ) ,
' crop_pct ' : 0.875 , ' interpolation ' : ' bicubic ' ,
' mean ' : IMAGENET_DEFAULT_MEAN , ' std ' : IMAGENET_DEFAULT_STD ,
' first_conv ' : ' stem.0 ' , ' classifier ' : ' head.fc ' ,
* * kwargs
}
default_cfgs = generate_defaults ( {
# timm specific variants
' convnext_atto.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' convnext_atto_ols.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' convnext_femto.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' convnext_femto_ols.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' convnext_pico.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 0.95 ) ,
' convnext_pico_ols.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_nano.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_nano_ols.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_tiny_hnf.timm_in1k ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth ' ,
crop_pct = 0.95 , test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_tiny.fb_in1k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_small.fb_in1k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_base.fb_in1k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_large.fb_in1k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth " ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_xlarge.untrained ' : _cfg ( ) ,
' convnext_tiny.fb_in22k_ft_in1k ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_small.fb_in22k_ft_in1k ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_base.fb_in22k_ft_in1k ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_large.fb_in22k_ft_in1k ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_xlarge.fb_in22k_ft_in1k ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth ' ,
test_input_size = ( 3 , 288 , 288 ) , test_crop_pct = 1.0 ) ,
' convnext_tiny.fb_in22k_ft_in1k_384 ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 , crop_mode = ' squash ' ) ,
' convnext_small..fb_in22k_ft_in1k_384 ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 , crop_mode = ' squash ' ) ,
' convnext_base.fb_in22k_ft_in1k_384 ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 , crop_mode = ' squash ' ) ,
' convnext_large.fb_in22k_ft_in1k_384 ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 , crop_mode = ' squash ' ) ,
' convnext_xlarge.fb_in22k_ft_in1k_384 ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth ' ,
input_size = ( 3 , 384 , 384 ) , pool_size = ( 12 , 12 ) , crop_pct = 1.0 , crop_mode = ' squash ' ) ,
' convnext_tiny_in22k.fb_in22k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth " , num_classes = 21841 ) ,
' convnext_small_in22k.fb_in22k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth " , num_classes = 21841 ) ,
' convnext_base_in22k.fb_in22k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth " , num_classes = 21841 ) ,
' convnext_large_in22k.fb_in22k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth " , num_classes = 21841 ) ,
' convnext_xlarge_in22k.fb_in22k ' : _cfg (
url = " https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth " , num_classes = 21841 ) ,
} )
@register_model
def convnext_atto ( pretrained = False , * * kwargs ) :
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@ -569,105 +572,7 @@ def convnext_large(pretrained=False, **kwargs):
@register_model
def convnext_tiny_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = ( 3 , 3 , 9 , 3 ) , dims = ( 96 , 192 , 384 , 768 ) , * * kwargs )
model = _create_convnext ( ' convnext_tiny_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_small_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 96 , 192 , 384 , 768 ] , * * kwargs )
model = _create_convnext ( ' convnext_small_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_base_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 128 , 256 , 512 , 1024 ] , * * kwargs )
model = _create_convnext ( ' convnext_base_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_large_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 192 , 384 , 768 , 1536 ] , * * kwargs )
model = _create_convnext ( ' convnext_large_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_xlarge_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 256 , 512 , 1024 , 2048 ] , * * kwargs )
model = _create_convnext ( ' convnext_xlarge_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_tiny_384_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = ( 3 , 3 , 9 , 3 ) , dims = ( 96 , 192 , 384 , 768 ) , * * kwargs )
model = _create_convnext ( ' convnext_tiny_384_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_small_384_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 96 , 192 , 384 , 768 ] , * * kwargs )
model = _create_convnext ( ' convnext_small_384_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_base_384_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 128 , 256 , 512 , 1024 ] , * * kwargs )
model = _create_convnext ( ' convnext_base_384_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_large_384_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 192 , 384 , 768 , 1536 ] , * * kwargs )
model = _create_convnext ( ' convnext_large_384_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_xlarge_384_in22ft1k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 256 , 512 , 1024 , 2048 ] , * * kwargs )
model = _create_convnext ( ' convnext_xlarge_384_in22ft1k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_tiny_in22k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = ( 3 , 3 , 9 , 3 ) , dims = ( 96 , 192 , 384 , 768 ) , * * kwargs )
model = _create_convnext ( ' convnext_tiny_in22k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_small_in22k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 96 , 192 , 384 , 768 ] , * * kwargs )
model = _create_convnext ( ' convnext_small_in22k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_base_in22k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 128 , 256 , 512 , 1024 ] , * * kwargs )
model = _create_convnext ( ' convnext_base_in22k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_large_in22k ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 192 , 384 , 768 , 1536 ] , * * kwargs )
model = _create_convnext ( ' convnext_large_in22k ' , pretrained = pretrained , * * model_args )
return model
@register_model
def convnext_xlarge_in22k ( pretrained = False , * * kwargs ) :
def convnext_xlarge ( pretrained = False , * * kwargs ) :
model_args = dict ( depths = [ 3 , 3 , 27 , 3 ] , dims = [ 256 , 512 , 1024 , 2048 ] , * * kwargs )
model = _create_convnext ( ' convnext_xlarge _in22k ' , pretrained = pretrained , * * model_args )
model = _create_convnext ( ' convnext_xlarge ' , pretrained = pretrained , * * model_args )
return model