|
|
|
@ -30,7 +30,8 @@ import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
|
|
|
|
|
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
|
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
|
|
|
|
|
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
|
|
|
|
|
from .registry import register_model
|
|
|
|
@ -106,7 +107,7 @@ default_cfgs = {
|
|
|
|
|
'vit_large_patch14_224': _cfg(url=''),
|
|
|
|
|
'vit_huge_patch14_224': _cfg(url=''),
|
|
|
|
|
'vit_giant_patch14_224': _cfg(url=''),
|
|
|
|
|
'vit_gee_patch14_224': _cfg(url=''),
|
|
|
|
|
'vit_gigantic_patch14_224': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# patch models, imagenet21k (weights from official Google JAX impl)
|
|
|
|
@ -179,17 +180,21 @@ default_cfgs = {
|
|
|
|
|
'vit_base_patch16_18x2_224': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
'vit_base_patch32_224_clip_laion2b': _cfg(
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
num_classes=512),
|
|
|
|
|
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
|
|
|
|
|
'vit_large_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
num_classes=768),
|
|
|
|
|
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
|
|
|
|
|
'vit_huge_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
num_classes=1024),
|
|
|
|
|
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
|
|
|
|
|
'vit_giant_patch14_224_clip_laion2b': _cfg(
|
|
|
|
|
hf_hub_id='',
|
|
|
|
|
num_classes=1024),
|
|
|
|
|
hf_hub_id='CLIP-ViT-g-14-laion2B-s12B-b42K',
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -960,12 +965,11 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_gee_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
|
As per https://twitter.com/wightmanr/status/1570549064667889666
|
|
|
|
|
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|