From eba07b0de7fcbd418ce652f7cb3162cda21c39a0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 16:45:11 -0800 Subject: [PATCH] Add eva models to beit.py --- README.md | 3 ++ timm/models/beit.py | 109 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 88 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 798f94f3..735cb5a4 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,9 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Dec 6, 2022 +* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain from https://github.com/baaivision/EVA + # Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm` diff --git a/timm/models/beit.py b/timm/models/beit.py index 1f6bf82b..c36683ef 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -1,8 +1,6 @@ """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) Model from official source: https://github.com/microsoft/unilm/tree/master/beit -and -https://github.com/microsoft/unilm/tree/master/beit2 @inproceedings{beit, title={{BEiT}: {BERT} Pre-Training of Image Transformers}, @@ -12,6 +10,8 @@ year={2022}, url={https://openreview.net/forum?id=p-BhZSz59o4} } +BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2 + @article{beitv2, title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers}, author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei}, @@ -21,6 +21,17 @@ archivePrefix={arXiv}, primaryClass={cs.CV} } +EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 + +@article{EVA, + title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, + author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, + Tiejun and Wang, Xinlong and Cao, Yue}, + journal={arXiv preprint arXiv:2211.07636}, + year={2022} +} + + At this point only the 1k fine-tuned classification weights and model configs have been added, see original source above for pre-training models and procedure. @@ -37,6 +48,9 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # https://github.com/facebookresearch/deit/ # https://github.com/facebookresearch/dino # --------------------------------------------------------' + +# EVA models Copyright (c) 2022 BAAI-Vision + import math from functools import partial from typing import Optional, Tuple @@ -46,9 +60,10 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from .helpers import build_model_with_cfg from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from .pretrained import generate_default_cfgs from .registry import register_model from .vision_transformer import checkpoint_filter_fn @@ -64,52 +79,72 @@ def _cfg(url='', **kwargs): } -default_cfgs = { - 'beit_base_patch16_224': _cfg( +default_cfgs = generate_default_cfgs({ + 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_base_patch16_384': _cfg( + 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), - 'beit_base_patch16_224_in22k': _cfg( + 'beit_base_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), - 'beit_large_patch16_224': _cfg( + 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_large_patch16_384': _cfg( + 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), - 'beit_large_patch16_512': _cfg( + 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', input_size=(3, 512, 512), crop_pct=1.0, ), - 'beit_large_patch16_224_in22k': _cfg( + 'beit_large_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), - 'beitv2_base_patch16_224': _cfg( + 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_base_patch16_224_in22k': _cfg( + 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_large_patch16_224': _cfg( + 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_large_patch16_224_in22k': _cfg( + 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), -} + + 'eva_giant_patch14_224.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + ), + 'eva_giant_patch14_336.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336)), + 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 336, 336)), + 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 560, 560)), +}) def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: @@ -415,7 +450,7 @@ def beit_base_patch16_224(pretrained=False, **kwargs): @register_model def beit_base_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -424,7 +459,7 @@ def beit_base_patch16_384(pretrained=False, **kwargs): @register_model def beit_base_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model @@ -433,7 +468,7 @@ def beit_base_patch16_224_in22k(pretrained=False, **kwargs): @register_model def beit_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -442,7 +477,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs): @register_model def beit_large_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -451,7 +486,7 @@ def beit_large_patch16_384(pretrained=False, **kwargs): @register_model def beit_large_patch16_512(pretrained=False, **kwargs): model_kwargs = dict( - img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs) return model @@ -460,7 +495,7 @@ def beit_large_patch16_512(pretrained=False, **kwargs): @register_model def beit_large_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model @@ -487,7 +522,7 @@ def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs): @register_model def beitv2_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -496,7 +531,33 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs): @register_model def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model + + +def eva_giant_patch14_224(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_336(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_560(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs) + return model