diff --git a/timm/models/beit.py b/timm/models/beit.py
index a2083a4a..60497d9a 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -1,6 +1,25 @@
 """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
 
 Model from official source: https://github.com/microsoft/unilm/tree/master/beit
+and
+https://github.com/microsoft/unilm/tree/master/beit2
+
+@inproceedings{beit,
+title={{BEiT}: {BERT} Pre-Training of Image Transformers},
+author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
+booktitle={International Conference on Learning Representations},
+year={2022},
+url={https://openreview.net/forum?id=p-BhZSz59o4}
+}
+
+@article{beitv2,
+title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
+author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
+year={2022},
+eprint={2208.06366},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
 
 At this point only the 1k fine-tuned classification weights and model configs have been added,
 see original source above for pre-training models and procedure.
@@ -27,6 +46,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from .registry import register_model
@@ -69,6 +89,26 @@ default_cfgs = {
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
         num_classes=21841,
     ),
+
+    'beitv2_base_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_base_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
+        crop_pct=0.95,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
 }
@@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def beitv2_base_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
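
For reference, a minimal usage sketch (not part of the diff): the @register_model decorators above make each beitv2_* name resolvable through timm.create_model, and pretrained=True pulls the checkpoint from the matching default_cfgs URL. Exact output shape assumes the 1k fine-tuned head for the non-in22k variants.

    import torch
    import timm

    # The four names this diff registers should now appear in the registry.
    print(timm.list_models('beitv2*'))

    # Instantiate one variant; pretrained=True would download the checkpoint
    # from the default_cfgs URL added above.
    model = timm.create_model('beitv2_base_patch16_224', pretrained=False)
    model.eval()

    # These configs are fixed at 224x224 input resolution with patch size 16.
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([1, 1000]); the *_in22k variants use num_classes=21841

Note the design choice carried over into the new default_cfgs entries: unlike the original BEiT weights, the v2 checkpoints are normalized with IMAGENET_DEFAULT_MEAN/IMAGENET_DEFAULT_STD, hence the new timm.data import in the second hunk.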