@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
from .registry import register_model

_logger = logging.getLogger(__name__)
@@ -492,7 +492,8 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
        model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+    if model.cls_token is not None:
+        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
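Why the new guard matters (an illustrative note, not part of the patch): ViT variants built with class_token=False, such as the *_gap_* models registered later in this file, set cls_token to None, so _load_weights must skip copying the JAX 'cls' parameter for them.

    import timm
    model = timm.create_model('vit_medium_patch16_gap_256')  # built with class_token=False
    assert model.cls_token is None  # hence the `if model.cls_token is not None` guard above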
@@ -630,51 +631,74 @@ def _cfg(url='', **kwargs):
    }

-default_cfgs = generate_defaults({
-    # patch models (weights from official Google JAX impl)
-    'vit_tiny_patch16_224.augreg_in21k_ft_1k': _cfg(
+default_cfgs = generate_default_cfgs({
+    # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
+    'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        custom_load=True),
-    'vit_tiny_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        custom_load=True),
-    'vit_small_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        custom_load=True),
-    'vit_small_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
        custom_load=True),
-    'vit_base_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
        custom_load=True),
-    'vit_base_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch8_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
        custom_load=True),
-    'vit_large_patch32_384.v1_in21k_ft_1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
        custom_load=True),
-    'vit_large_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    # re-finetuned augreg 21k FT on in1k weights
+    'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
+        file='b16_augreg-a-8.pth'),
+    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+    'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+    # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
+    'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
+    'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
+    'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    # How to train your ViT (augreg) weights trained on in1k
+    'vit_base_patch16_224.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_base_patch16_384.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch14_224.untrained': _cfg(url=''),
    'vit_huge_patch14_224.untrained': _cfg(url=''),
    'vit_giant_patch14_224.untrained': _cfg(url=''),
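With generate_default_cfgs, every '<model>.<tag>' key above becomes an individually selectable pretrained configuration. A minimal usage sketch (assuming the tag-aware pretrained resolution that generate_default_cfgs feeds into, which sits outside this diff):

    import timm
    # select one tagged weight set from default_cfgs, e.g. the augreg in21k -> in1k fine-tune
    model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
    model.eval()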
@@ -682,6 +706,15 @@ default_cfgs = generate_defaults({
    # patch models, imagenet21k (weights from official Google JAX impl)
+    'vit_large_patch32_224.v1_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+        num_classes=21843),
+    'vit_huge_patch14_224.v1_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
+        hf_hub_id='timm/vit_huge_patch14_224_in21k',
+        custom_load=True, num_classes=21843),
+    # How to train your ViT (augreg) weights, pretrained on in21k
    'vit_tiny_patch16_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        custom_load=True, num_classes=21843),
@@ -700,16 +733,9 @@ default_cfgs = generate_defaults({
    'vit_base_patch8_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        custom_load=True, num_classes=21843),
-    'vit_large_patch32_224.v1_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
-        num_classes=21843),
    'vit_large_patch16_224.augreg_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
        custom_load=True, num_classes=21843),
-    'vit_huge_patch14_224.v1_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub_id='timm/vit_huge_patch14_224_in21k',
-        custom_load=True, num_classes=21843),

    # SAM trained models (https://arxiv.org/abs/2106.01548)
    'vit_base_patch32_224.sam': _cfg(
@@ -736,7 +762,7 @@ default_cfgs = generate_defaults({
    'vit_base_patch16_224_miil.in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
-    'vit_base_patch16_224_miil.in21k_ft_1k': _cfg(
+    'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
@@ -744,14 +770,15 @@ default_cfgs = generate_defaults({
    'vit_base_patch16_rpn_224.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
    'vit_medium_patch16_gap_240.in12k': _cfg(
-        url='',
+        hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
        input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
-    'vit_medium_patch16_gap_256.in12k_ft_1k': _cfg(
-        url='',
+    'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
        input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_gap_384.in12k_ft_1k': _cfg(
-        url='',
-        input_size=(3, 384, 384), crop_pct=0.95),
+    'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
    'vit_base_patch16_gap_224': _cfg(),

    # CLIP pretrained image tower and related fine-tuned weights
    'vit_base_patch32_clip_224.laion2b': _cfg(
@@ -781,15 +808,16 @@ default_cfgs = generate_defaults({
    'vit_base_patch32_clip_384.laion2b_ft_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in1k',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
    'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
    'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
-    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
    'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
        hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@@ -816,10 +844,11 @@ default_cfgs = generate_defaults({
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
    'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
    'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
    'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@@ -866,7 +895,8 @@ default_cfgs = generate_defaults({
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
    'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
    'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
        hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
@@ -876,10 +906,15 @@ default_cfgs = generate_defaults({
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
    'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
    'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
-        #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
+    'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
    'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
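Several entries above switch to crop_mode='squash' at 384x384: evaluation preprocessing resizes (squashes) the image directly to input_size rather than resize-then-center-crop. A hedged sketch of building the matching eval transform, assuming timm.data propagates crop_mode from the pretrained config (that plumbing is outside this diff):

    import timm
    from timm.data import resolve_data_config, create_transform
    model = timm.create_model('vit_base_patch16_clip_384.openai_ft_in12k_in1k', pretrained=False)
    cfg = resolve_data_config({}, model=model)  # input_size, mean/std, crop_pct, and (if supported) crop_mode
    transform = create_transform(**cfg, is_training=False)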
@@ -1118,37 +1153,48 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):

@register_model
def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
    model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
    model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
    model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
    return model

+@register_model
+def vit_base_patch16_gap_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
+        global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
+    return model
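The new GAP ViT-Base classifies from average-pooled patch tokens rather than a class token. A small illustrative check (not part of the patch; num_classes=0 returns the pooled features):

    import torch
    import timm
    m = timm.create_model('vit_base_patch16_gap_224', num_classes=0)  # headless feature extractor
    feats = m(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # pooled patch-token features, torch.Size([1, 768])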

@register_model
def vit_base_patch32_clip_224(pretrained=False, **kwargs):
    """ ViT-B/32 CLIP image tower @ 224x224