|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
|
|
|
|
|
""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
|
|
|
|
|
|
|
|
|
|
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
|
|
|
|
|
|
|
|
|
@ -68,82 +68,6 @@ from .registry import register_model
|
|
|
|
|
from .vision_transformer import checkpoint_filter_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
|
|
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 512, 512), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
crop_pct=0.95,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
|
|
|
|
),
|
|
|
|
|
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
|
|
|
|
input_size=(3, 336, 336)),
|
|
|
|
|
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 336, 336)),
|
|
|
|
|
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 560, 560)),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
|
|
|
|
|
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
|
|
|
|
# cls to token & token 2 cls & cls to cls
|
|
|
|
@ -416,6 +340,82 @@ class Beit(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
|
|
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
|
|
|
|
|
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
|
|
|
|
|
input_size=(3, 512, 512), crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
|
|
|
|
|
crop_pct=0.95,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
|
|
|
|
|
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
|
|
|
|
|
num_classes=21841,
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
|
|
|
|
|
),
|
|
|
|
|
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
|
|
|
|
|
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
|
|
|
|
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
|
|
|
|
|
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
|
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
|
|
|
|
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _beit_checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
if 'module' in state_dict:
|
|
|
|
|
# beit v2 didn't strip module
|
|
|
|
@ -425,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
def _create_beit(variant, pretrained=False, **kwargs):
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Beit models.')
|
|
|
|
|
raise RuntimeError('features_only not implemented for BEiT models.')
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
Beit, variant, pretrained,
|
|
|
|
@ -453,15 +453,6 @@ def beit_base_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
|
|
|
|
|
model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
@ -489,15 +480,6 @@ def beit_large_patch16_512(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
@ -507,15 +489,6 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
@ -525,15 +498,6 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
|
|
|
|
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
|
|
|
|
model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def eva_giant_patch14_224(pretrained=False, **kwargs):
|
|
|
|
|
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
|
|
|
|