@@ -1,8 +1,6 @@
-""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
 Model from official source: https://github.com/microsoft/unilm/tree/master/beit
-and
-https://github.com/microsoft/unilm/tree/master/beit2
 @inproceedings{beit,
 title={{BEiT}: {BERT} Pre-Training of Image Transformers},
@@ -12,6 +10,8 @@ year={2022},
 url={https://openreview.net/forum?id=p-BhZSz59o4}
 }
+BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
 @article{beitv2,
 title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
 author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
@@ -21,6 +21,17 @@ archivePrefix={arXiv},
 primaryClass={cs.CV}
 }
+EVA from https://github.com/baaivision/EVA, paper: https://arxiv.org/abs/2211.07636
+@article{EVA,
+title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
+author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
+Tiejun and Wang, Xinlong and Cao, Yue},
+journal={arXiv preprint arXiv:2211.07636},
+year={2022}
+}
 At this point only the 1k fine-tuned classification weights and model configs have been added,
 see original source above for pre-training models and procedure.
@@ -37,6 +48,9 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
 # https://github.com/facebookresearch/deit/
 # https://github.com/facebookresearch/dino
 # --------------------------------------------------------'
+# EVA models Copyright (c) 2022 BAAI-Vision
 import math
 from functools import partial
 from typing import Optional, Tuple
@@ -49,69 +63,12 @@ from torch.utils.checkpoint import checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from ._builder import build_model_with_cfg
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 from .vision_transformer import checkpoint_filter_fn
 __all__ = ['Beit']
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-default_cfgs = {
-    'beit_base_patch16_224': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
-    'beit_base_patch16_384': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
-    ),
-    'beit_base_patch16_224_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
-        num_classes=21841,
-    ),
-    'beit_large_patch16_224': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
-    'beit_large_patch16_384': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
-    ),
-    'beit_large_patch16_512': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
-        input_size=(3, 512, 512), crop_pct=1.0,
-    ),
-    'beit_large_patch16_224_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
-        num_classes=21841,
-    ),
-    'beitv2_base_patch16_224': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_base_patch16_224_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_large_patch16_224': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
-        crop_pct=0.95,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-    'beitv2_large_patch16_224_in22k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
-    ),
-}
 def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
@@ -385,6 +342,82 @@ class Beit(nn.Module):
         return x
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+default_cfgs = generate_default_cfgs({
+    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
+    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0,
+    ),
+    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
+        num_classes=21841,
+    ),
+    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
+    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0,
+    ),
+    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
+        input_size=(3, 512, 512), crop_pct=1.0,
+    ),
+    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
+        num_classes=21841,
+    ),
+    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
+        crop_pct=0.95,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
+    ),
+    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
+        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
+})
 def _beit_checkpoint_filter_fn(state_dict, model):
     if 'module' in state_dict:
         # beit v2 didn't strip module
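Aside on the config scheme introduced above: generate_default_cfgs keys each weight set as
'architecture.pretrained_tag' (e.g. 'beit_base_patch16_224.in22k_ft_in22k'), so one registered
architecture can carry several pretrained weights instead of needing a separate *_in22k entry
point per weight set. A minimal usage sketch, assuming a timm version (>= 0.8) whose
create_model resolves tag-qualified names like these:

import timm

# Same architecture function, different weights selected by the tag after the dot.
model_in1k = timm.create_model('beit_base_patch16_224.in22k_ft_in22k_in1k', pretrained=True)
model_in22k = timm.create_model('beit_base_patch16_224.in22k_ft_in22k', pretrained=True)  # 21841-class head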
@@ -394,7 +427,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):
 def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Beit models.')
+        raise RuntimeError('features_only not implemented for BEiT models.')
     model = build_model_with_cfg(
         Beit, variant, pretrained,
@@ -416,25 +449,16 @@ def beit_base_patch16_224(pretrained=False, **kwargs):
 @register_model
 def beit_base_patch16_384(pretrained=False, **kwargs):
     model_kwargs = dict(
-        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
     model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
-@register_model
-def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
-    model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
-    return model
 @register_model
 def beit_large_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -443,7 +467,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs):
 @register_model
 def beit_large_patch16_384(pretrained=False, **kwargs):
     model_kwargs = dict(
-        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -452,52 +476,52 @@ def beit_large_patch16_384(pretrained=False, **kwargs):
 @register_model
 def beit_large_patch16_512(pretrained=False, **kwargs):
     model_kwargs = dict(
-        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
     return model
 @register_model
-def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
+def beitv2_base_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 @register_model
-def beitv2_base_patch16_224(pretrained=False, **kwargs):
+def beitv2_large_patch16_224(pretrained=False, **kwargs):
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 @register_model
-def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
+def eva_giant_patch14_224(pretrained=False, **kwargs):
+    """ EVA-g model https://arxiv.org/abs/2211.07636 """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
+    model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
     return model
 @register_model
-def beitv2_large_patch16_224(pretrained=False, **kwargs):
+def eva_giant_patch14_336(pretrained=False, **kwargs):
+    """ EVA-g model https://arxiv.org/abs/2211.07636 """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
+    model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
     return model
 @register_model
-def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
+def eva_giant_patch14_560(pretrained=False, **kwargs):
+    """ EVA-g model https://arxiv.org/abs/2211.07636 """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
-    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
+    model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
     return model
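And a hedged end-to-end sketch for the new EVA registrations: the model name, 336px input
size, and 1000-class output follow the cfgs above (EVA-g is width 1408, depth 40, with
mlp_ratio=6144/1408 recovering the 6144-dim MLP hidden size); the surrounding inference
code is ordinary timm/PyTorch boilerplate, not part of this diff:

import torch
import timm

model = timm.create_model('eva_giant_patch14_336.clip_ft_in1k', pretrained=True)
model.eval()

x = torch.randn(1, 3, 336, 336)  # matches input_size=(3, 336, 336) in the cfg
with torch.no_grad():
    logits = model(x)  # shape (1, 1000), ImageNet-1k logits
print(logits.argmax(dim=1))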