|
|
@ -25,6 +25,7 @@ from functools import partial
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import huggingface_hub.file_download
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
@ -106,7 +107,7 @@ default_cfgs = {
|
|
|
|
'vit_large_patch14_224': _cfg(url=''),
|
|
|
|
'vit_large_patch14_224': _cfg(url=''),
|
|
|
|
'vit_huge_patch14_224': _cfg(url=''),
|
|
|
|
'vit_huge_patch14_224': _cfg(url=''),
|
|
|
|
'vit_giant_patch14_224': _cfg(url=''),
|
|
|
|
'vit_giant_patch14_224': _cfg(url=''),
|
|
|
|
'vit_gigantic_patch14_224': _cfg(url=''),
|
|
|
|
'vit_gee_patch14_224': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# patch models, imagenet21k (weights from official Google JAX impl)
|
|
|
|
# patch models, imagenet21k (weights from official Google JAX impl)
|
|
|
@ -177,6 +178,20 @@ default_cfgs = {
|
|
|
|
'vit_small_patch16_36x1_224': _cfg(url=''),
|
|
|
|
'vit_small_patch16_36x1_224': _cfg(url=''),
|
|
|
|
'vit_small_patch16_18x2_224': _cfg(url=''),
|
|
|
|
'vit_small_patch16_18x2_224': _cfg(url=''),
|
|
|
|
'vit_base_patch16_18x2_224': _cfg(url=''),
|
|
|
|
'vit_base_patch16_18x2_224': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base_patch32_224_clip_laion2b': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
|
|
|
num_classes=512),
|
|
|
|
|
|
|
|
'vit_large_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
|
|
|
num_classes=768),
|
|
|
|
|
|
|
|
'vit_huge_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
|
|
|
num_classes=1024),
|
|
|
|
|
|
|
|
'vit_giant_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
|
|
|
num_classes=1024),
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -221,8 +236,18 @@ class LayerScale(nn.Module):
|
|
|
|
class Block(nn.Module):
|
|
|
|
class Block(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
|
|
|
|
self,
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads,
|
|
|
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
drop=0.,
|
|
|
|
|
|
|
|
attn_drop=0.,
|
|
|
|
|
|
|
|
init_values=None,
|
|
|
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
|
|
|
norm_layer=nn.LayerNorm
|
|
|
|
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
|
|
|
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
|
|
@ -244,8 +269,18 @@ class Block(nn.Module):
|
|
|
|
class ResPostBlock(nn.Module):
|
|
|
|
class ResPostBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
|
|
|
|
self,
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads,
|
|
|
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
drop=0.,
|
|
|
|
|
|
|
|
attn_drop=0.,
|
|
|
|
|
|
|
|
init_values=None,
|
|
|
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
|
|
|
norm_layer=nn.LayerNorm
|
|
|
|
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.init_values = init_values
|
|
|
|
self.init_values = init_values
|
|
|
|
|
|
|
|
|
|
|
@ -274,8 +309,19 @@ class ResPostBlock(nn.Module):
|
|
|
|
class ParallelBlock(nn.Module):
|
|
|
|
class ParallelBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
|
|
|
|
self,
|
|
|
|
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads,
|
|
|
|
|
|
|
|
num_parallel=2,
|
|
|
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
init_values=None,
|
|
|
|
|
|
|
|
drop=0.,
|
|
|
|
|
|
|
|
attn_drop=0.,
|
|
|
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
|
|
|
norm_layer=nn.LayerNorm
|
|
|
|
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.num_parallel = num_parallel
|
|
|
|
self.num_parallel = num_parallel
|
|
|
|
self.attns = nn.ModuleList()
|
|
|
|
self.attns = nn.ModuleList()
|
|
|
@ -320,10 +366,31 @@ class VisionTransformer(nn.Module):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
|
|
|
self,
|
|
|
|
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
|
|
|
img_size=224,
|
|
|
|
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
|
|
|
patch_size=16,
|
|
|
|
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
|
|
|
in_chans=3,
|
|
|
|
|
|
|
|
num_classes=1000,
|
|
|
|
|
|
|
|
global_pool='token',
|
|
|
|
|
|
|
|
embed_dim=768,
|
|
|
|
|
|
|
|
depth=12,
|
|
|
|
|
|
|
|
num_heads=12,
|
|
|
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
|
|
|
qkv_bias=True,
|
|
|
|
|
|
|
|
init_values=None,
|
|
|
|
|
|
|
|
class_token=True,
|
|
|
|
|
|
|
|
no_embed_class=False,
|
|
|
|
|
|
|
|
pre_norm=False,
|
|
|
|
|
|
|
|
fc_norm=None,
|
|
|
|
|
|
|
|
drop_rate=0.,
|
|
|
|
|
|
|
|
attn_drop_rate=0.,
|
|
|
|
|
|
|
|
drop_path_rate=0.,
|
|
|
|
|
|
|
|
weight_init='',
|
|
|
|
|
|
|
|
embed_layer=PatchEmbed,
|
|
|
|
|
|
|
|
norm_layer=None,
|
|
|
|
|
|
|
|
act_layer=None,
|
|
|
|
|
|
|
|
block_fn=Block,
|
|
|
|
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
img_size (int, tuple): input image size
|
|
|
|
img_size (int, tuple): input image size
|
|
|
@ -362,19 +429,34 @@ class VisionTransformer(nn.Module):
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
|
|
|
|
|
|
|
self.patch_embed = embed_layer(
|
|
|
|
self.patch_embed = embed_layer(
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
|
|
img_size=img_size,
|
|
|
|
|
|
|
|
patch_size=patch_size,
|
|
|
|
|
|
|
|
in_chans=in_chans,
|
|
|
|
|
|
|
|
embed_dim=embed_dim,
|
|
|
|
|
|
|
|
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
|
|
|
|
|
|
|
)
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
|
|
|
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
|
|
|
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
|
|
|
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
|
|
|
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
|
|
|
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
|
|
|
|
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
block_fn(
|
|
|
|
block_fn(
|
|
|
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
|
|
|
|
dim=embed_dim,
|
|
|
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
init_values=init_values,
|
|
|
|
|
|
|
|
drop=drop_rate,
|
|
|
|
|
|
|
|
attn_drop=attn_drop_rate,
|
|
|
|
|
|
|
|
drop_path=dpr[i],
|
|
|
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
|
|
|
act_layer=act_layer
|
|
|
|
|
|
|
|
)
|
|
|
|
for i in range(depth)])
|
|
|
|
for i in range(depth)])
|
|
|
|
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
|
|
|
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
@ -445,6 +527,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
x = self._pos_embed(x)
|
|
|
|
x = self._pos_embed(x)
|
|
|
|
|
|
|
|
x = self.norm_pre(x)
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
x = checkpoint_seq(self.blocks, x)
|
|
|
|
x = checkpoint_seq(self.blocks, x)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -623,6 +706,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
|
|
|
return posemb
|
|
|
|
return posemb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_openai_clip(state_dict, model):
|
|
|
|
|
|
|
|
out_dict = {}
|
|
|
|
|
|
|
|
swaps = [
|
|
|
|
|
|
|
|
('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
|
|
|
|
|
|
|
|
('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
|
|
|
|
|
|
|
|
('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
|
|
|
if not k.startswith('visual.'):
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
for sp in swaps:
|
|
|
|
|
|
|
|
k = k.replace(sp[0], sp[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if k == 'proj':
|
|
|
|
|
|
|
|
k = 'head.weight'
|
|
|
|
|
|
|
|
v = v.transpose(0, 1)
|
|
|
|
|
|
|
|
out_dict['head.bias'] = torch.zeros(v.shape[0])
|
|
|
|
|
|
|
|
elif k == 'class_embedding':
|
|
|
|
|
|
|
|
k = 'cls_token'
|
|
|
|
|
|
|
|
v = v.unsqueeze(0).unsqueeze(1)
|
|
|
|
|
|
|
|
elif k == 'pos_embed':
|
|
|
|
|
|
|
|
v = v.unsqueeze(0)
|
|
|
|
|
|
|
|
if v.shape[1] != model.pos_embed.shape[1]:
|
|
|
|
|
|
|
|
# To resize pos embedding when using model at different size from pretrained weights
|
|
|
|
|
|
|
|
v = resize_pos_embed(
|
|
|
|
|
|
|
|
v,
|
|
|
|
|
|
|
|
model.pos_embed,
|
|
|
|
|
|
|
|
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
|
|
|
|
|
|
|
|
model.patch_embed.grid_size
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
out_dict[k] = v
|
|
|
|
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
|
|
|
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
import re
|
|
|
|
import re
|
|
|
@ -631,6 +748,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
|
|
|
# For deit models
|
|
|
|
# For deit models
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'visual.class_embedding' in state_dict:
|
|
|
|
|
|
|
|
return _convert_openai_clip(state_dict, model)
|
|
|
|
|
|
|
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
|
|
|
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
|
|
|
# For old models that I trained prior to conv based patchification
|
|
|
|
# For old models that I trained prior to conv based patchification
|
|
|
@ -833,7 +953,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def vit_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
def vit_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
@ -841,11 +961,12 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
|
|
|
|
def vit_gee_patch14_224(pretrained=False, **kwargs):
|
|
|
|
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
""" ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
|
|
|
|
As per https://twitter.com/wightmanr/status/1570549064667889666
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1085,3 +1206,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
|
|
|
|
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
|
|
|
|
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-B/32
|
|
|
|
|
|
|
|
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Large model (ViT-L/14)
|
|
|
|
|
|
|
|
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
|
|
|
|
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
|
|
|
|
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|