pull/804/merge
Ross Wightman 2 years ago
parent dc90816f26
commit 914544fc81

@ -1,6 +1,25 @@
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
and
https://github.com/microsoft/unilm/tree/master/beit2
@inproceedings{beit,
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=p-BhZSz59o4}
}
@article{beitv2,
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
year={2022},
eprint={2208.06366},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
@ -27,6 +46,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
@ -69,6 +89,26 @@ default_cfgs = {
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beitv2_base_patch16_224': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
}
@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model

Loading…
Cancel
Save