""" Vision Transformer (ViT) in PyTorch A PyTorch implement of Vision Transformers as described in: 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 The official jax code is released and available at https://github.com/google-research/vision_transformer Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020, Ross Wightman """ import logging import math from collections import OrderedDict from functools import partial from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._pretrained import generate_default_cfgs from ._registry import register_model __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class Block(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x class ResPostBlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.init_values = init_values self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.init_weights() def init_weights(self): # NOTE this init overrides that base model init with specific changes for the block type if self.init_values is not None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) def forward(self, x): x = x + self.drop_path1(self.norm1(self.attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x class ParallelBlock(nn.Module): def __init__( self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() self.ffns = nn.ModuleList() for _ in range(num_parallel): self.attns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) self.ffns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) def _forward_jit(self, x): x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x @torch.jit.ignore def _forward(self, x): x = x + sum(attn(x) for attn in self.attns) x = x + sum(ffn(x) for ffn in self.ffns) return x def forward(self, x): if torch.jit.is_scripting() or torch.jit.is_tracing(): return self._forward_jit(x) else: return self._forward(x) class VisionTransformer(nn.Module): """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, pre_norm=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head global_pool (str): type of global pooling for final sequence (default: 'token') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values class_token (bool): use class token fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') assert class_token or global_pool != 'token' use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer ) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path, prefix=''): _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} @torch.jit.ignore def group_matcher(self, coarse=False): return dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.head def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x): if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + self.pos_embed if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.pos_embed return self.pos_drop(x) def forward_features(self, x): x = self.patch_embed(x) x = self._pos_embed(x) x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) return x def init_weights_vit_timm(module: nn.Module, name: str = ''): """ ViT weight initialization, original timm impl (for reproducibility) """ if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): """ ViT weight initialization, matching JAX (Flax) impl """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ if isinstance(module, nn.Linear): if 'qkv' in name: # treat the weights of Q, K, V separately val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) nn.init.uniform_(module.weight, -val, val) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: return init_weights_vit_moco else: return init_weights_vit_timm @torch.no_grad() def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): """ Load weights from .npz checkpoints for official Google Brain Flax implementation """ import numpy as np def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) if not prefix and 'opt/target/embedding/kernel' in w: prefix = 'opt/target/' if hasattr(model.patch_embed, 'backbone'): # hybrid backbone = model.patch_embed.backbone stem_only = not hasattr(backbone, 'stem') stem = backbone if stem_only else backbone.stem stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) if not stem_only: for i, stage in enumerate(backbone.stages): for j, block in enumerate(stage.blocks): bp = f'{prefix}block{i + 1}/unit{j + 1}/' for r in range(3): getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) if block.downsample is not None: block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) else: embed_conv_w = adapt_input_conv( model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) if model.cls_token is not None: model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] if num_prefix_tokens: posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] ntok_new -= num_prefix_tokens else: posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb def _convert_openai_clip(state_dict, model): out_dict = {} swaps = [ ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), ] for k, v in state_dict.items(): if not k.startswith('visual.'): continue for sp in swaps: k = k.replace(sp[0], sp[1]) if k == 'proj': k = 'head.weight' v = v.transpose(0, 1) out_dict['head.bias'] = torch.zeros(v.shape[0]) elif k == 'class_embedding': k = 'cls_token' v = v.unsqueeze(0).unsqueeze(1) elif k == 'pos_embed': v = v.unsqueeze(0) if v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) out_dict[k] = v return out_dict def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] if 'visual.class_embedding' in state_dict: return _convert_openai_clip(state_dict, model) for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) elif adapt_layer_scale and 'gamma_' in k: # remap layer-scale gamma into sub-module (deit3 models) k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue out_dict[k] = v return out_dict def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = generate_default_cfgs({ # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), # re-finetuned augreg 21k FT on in1k weights 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( file='b16_augreg-a-8.pth'), 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg( url=''), 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( url=''), # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'), 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'), 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), crop_pct=1.0), # How to train your ViT (augreg) weights trained on in1k 'vit_base_patch16_224.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), 'vit_base_patch16_384.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch14_224.untrained': _cfg(url=''), 'vit_huge_patch14_224.untrained': _cfg(url=''), 'vit_giant_patch14_224.untrained': _cfg(url=''), 'vit_gigantic_patch14_224.untrained': _cfg(url=''), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_large_patch32_224.v1_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843), 'vit_huge_patch14_224.v1_in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub_id='timm/vit_huge_patch14_224_in21k', custom_load=True, num_classes=21843), # How to train your ViT (augreg) weights, pretrained on in21k 'vit_tiny_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_small_patch32_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_small_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_base_patch32_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_base_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_base_patch8_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), 'vit_large_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', custom_load=True, num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) 'vit_base_patch32_224.sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True), 'vit_base_patch16_224.sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True), # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) 'vit_small_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_small_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), # custom timm variants 'vit_base_patch16_rpn_224.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), 'vit_medium_patch16_gap_240.in12k': _cfg( hf_hub_id='timm/vit_medium_patch16_gap_240.in12k', input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), 'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k', input_size=(3, 256, 256), crop_pct=0.95), 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k', input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), 'vit_base_patch16_gap_224': _cfg(), # CLIP pretrained image tower and related fine-tuned weights 'vit_base_patch32_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.laion2b': _cfg( #hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), 'vit_large_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), 'vit_huge_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_giant_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), 'vit_base_patch32_clip_224.openai': _cfg( hf_hub_id='timm/clip_vit_base_patch32_224.openai', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.openai': _cfg( hf_hub_id='timm/clip_vit_base_patch16_224.openai', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_large_patch14_clip_224.openai': _cfg( hf_hub_id='timm/clip_vit_large_patch14_224.openai', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), # experimental (may be removed) 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'eva_large_patch14_196.in22k_ft_in1k': _cfg( hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in1k': _cfg( hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), }) def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') return build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs, ) @register_model def vit_tiny_patch16_224(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_tiny_patch16_384(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch32_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) at 384x384. """ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch8_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/8) """ model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch8_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch32_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch14_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) """ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_giant_patch14_224(pretrained=False, **kwargs): """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_gigantic_patch14_224(pretrained=False, **kwargs): """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) return model @register_model def vit_medium_patch16_gap_240(pretrained=False, **kwargs): """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) return model @register_model def vit_medium_patch16_gap_256(pretrained=False, **kwargs): """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) return model @register_model def vit_medium_patch16_gap_384(pretrained=False, **kwargs): """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_gap_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs) model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_clip_224(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 224x224 """ model_kwargs = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_clip_384(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 384x384 """ model_kwargs = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_clip_448(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 448x448 """ model_kwargs = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_clip_224(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_clip_384(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower @ 384x384 """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower """ model_kwargs = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ model_kwargs = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs) return model @register_model def vit_huge_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower. """ model_kwargs = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_huge_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ model_kwargs = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs) return model @register_model def vit_giant_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 Pretrained weights from CLIP image tower. """ model_kwargs = dict( patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs) return model # Experimental models below @register_model def vit_base_patch32_plus_256(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32+) """ model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_plus_240(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16+) """ model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ residual post-norm """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_36x1_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict( patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model @register_model def eva_large_patch14_196(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" model_kwargs = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) return model @register_model def eva_large_patch14_336(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" model_kwargs = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) return model