|
|
|
@ -1,8 +1,6 @@
|
|
|
|
|
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
|
|
|
|
|
""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
|
|
|
|
|
|
|
|
|
|
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
|
|
|
|
|
and
|
|
|
|
|
https://github.com/microsoft/unilm/tree/master/beit2
|
|
|
|
|
|
|
|
|
|
@inproceedings{beit,
|
|
|
|
|
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
|
|
|
|
@ -12,6 +10,8 @@ year={2022},
|
|
|
|
|
url={https://openreview.net/forum?id=p-BhZSz59o4}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
|
|
|
|
|
|
|
|
|
|
@article{beitv2,
|
|
|
|
|
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
|
|
|
|
|
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
|
|
|
|
@ -21,6 +21,17 @@ archivePrefix={arXiv},
|
|
|
|
|
primaryClass={cs.CV}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
|
|
|
|
|
|
|
|
|
|
@article{EVA,
|
|
|
|
|
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
|
|
|
|
|
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
|
|
|
|
|
Tiejun and Wang, Xinlong and Cao, Yue},
|
|
|
|
|
journal={arXiv preprint arXiv:2211.07636},
|
|
|
|
|
year={2022}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
At this point only the 1k fine-tuned classification weights and model configs have been added,
|
|
|
|
|
see original source above for pre-training models and procedure.
|
|
|
|
|
|
|
|
|
@ -37,6 +48,9 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
|
|
|
|
|
# https://github.com/facebookresearch/deit/
|
|
|
|
|
# https://github.com/facebookresearch/dino
|
|
|
|
|
# --------------------------------------------------------'
|
|
|
|
|
|
|
|
|
|
# EVA models Copyright (c) 2022 BAAI-Vision
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
@ -46,72 +60,14 @@ import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
|
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .vision_transformer import checkpoint_filter_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
|
|
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
'beit_base_patch16_224': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_base_patch16_384': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_base_patch16_224_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_large_patch16_384': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_512': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 512, 512), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'beitv2_base_patch16_224': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_base_patch16_224_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
crop_pct=0.95,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
|
|
|
|
|
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
|
|
|
|
# cls to token & token 2 cls & cls to cls
|
|
|
|
@ -384,6 +340,82 @@ class Beit(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
|
|
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 512, 512), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
crop_pct=0.95,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
|
|
|
|
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _beit_checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
if 'module' in state_dict:
|
|
|
|
|
# beit v2 didn't strip module
|
|
|
|
@ -393,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
def _create_beit(variant, pretrained=False, **kwargs):
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Beit models.')
|
|
|
|
|
raise RuntimeError('features_only not implemented for BEiT models.')
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
Beit, variant, pretrained,
|
|
|
|
@ -415,25 +447,16 @@ def beit_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_base_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
|
|
|
|
|
model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
|
|
|
|
|
model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
@ -442,7 +465,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
@ -451,52 +474,52 @@ def beit_large_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_512(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
def beitv2_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
def beitv2_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
def eva_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
|
|
|
|
model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
def eva_giant_patch14_336(pretrained=False, **kwargs):
|
|
|
|
|
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
|
|
|
|
model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
def eva_giant_patch14_560(pretrained=False, **kwargs):
|
|
|
|
|
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
|
|
|
|
model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|