From eba07b0de7fcbd418ce652f7cb3162cda21c39a0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 16:45:11 -0800 Subject: [PATCH 1/9] Add eva models to beit.py --- README.md | 3 ++ timm/models/beit.py | 109 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 88 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 798f94f3..735cb5a4 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,9 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Dec 6, 2022 +* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain from https://github.com/baaivision/EVA + # Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm` diff --git a/timm/models/beit.py b/timm/models/beit.py index 1f6bf82b..c36683ef 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -1,8 +1,6 @@ """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) Model from official source: https://github.com/microsoft/unilm/tree/master/beit -and -https://github.com/microsoft/unilm/tree/master/beit2 @inproceedings{beit, title={{BEiT}: {BERT} Pre-Training of Image Transformers}, @@ -12,6 +10,8 @@ year={2022}, url={https://openreview.net/forum?id=p-BhZSz59o4} } +BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2 + @article{beitv2, title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers}, author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei}, @@ -21,6 +21,17 @@ archivePrefix={arXiv}, primaryClass={cs.CV} } +EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 + +@article{EVA, + title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, + author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, + Tiejun and Wang, Xinlong and Cao, Yue}, + journal={arXiv preprint arXiv:2211.07636}, + year={2022} +} + + At this point only the 1k fine-tuned classification weights and model configs have been added, see original source above for pre-training models and procedure. 
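For orientation, a minimal sketch of how the tagged EVA-g weights added in this patch are selected, assuming the multi-weight `model_arch.pretrained_tag` behaviour of `timm.create_model` described in the Dec 5 README note above:

```python
import timm

# The bare architecture name resolves to that architecture's default pretrained tag.
model = timm.create_model('eva_giant_patch14_224', pretrained=True)

# An explicit tag selects a specific weight set, e.g. the EVA-CLIP visual encoder
# fine-tuned on ImageNet-1k, hosted under BAAI/EVA on the Hugging Face Hub.
model = timm.create_model('eva_giant_patch14_224.clip_ft_in1k', pretrained=True)
model.eval()
```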
@@ -37,6 +48,9 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # https://github.com/facebookresearch/deit/ # https://github.com/facebookresearch/dino # --------------------------------------------------------' + +# EVA models Copyright (c) 2022 BAAI-Vision + import math from functools import partial from typing import Optional, Tuple @@ -46,9 +60,10 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from .helpers import build_model_with_cfg from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from .pretrained import generate_default_cfgs from .registry import register_model from .vision_transformer import checkpoint_filter_fn @@ -64,52 +79,72 @@ def _cfg(url='', **kwargs): } -default_cfgs = { - 'beit_base_patch16_224': _cfg( +default_cfgs = generate_default_cfgs({ + 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_base_patch16_384': _cfg( + 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), - 'beit_base_patch16_224_in22k': _cfg( + 'beit_base_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), - 'beit_large_patch16_224': _cfg( + 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_large_patch16_384': _cfg( + 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), - 'beit_large_patch16_512': _cfg( + 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', input_size=(3, 512, 512), crop_pct=1.0, ), - 'beit_large_patch16_224_in22k': _cfg( + 'beit_large_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), - 'beitv2_base_patch16_224': _cfg( + 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_base_patch16_224_in22k': _cfg( + 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_large_patch16_224': _cfg( + 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - 'beitv2_large_patch16_224_in22k': _cfg( + 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( 
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), -} + + 'eva_giant_patch14_224.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + ), + 'eva_giant_patch14_336.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336)), + 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 336, 336)), + 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', + hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 560, 560)), +}) def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: @@ -415,7 +450,7 @@ def beit_base_patch16_224(pretrained=False, **kwargs): @register_model def beit_base_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -424,7 +459,7 @@ def beit_base_patch16_384(pretrained=False, **kwargs): @register_model def beit_base_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model @@ -433,7 +468,7 @@ def beit_base_patch16_224_in22k(pretrained=False, **kwargs): @register_model def beit_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -442,7 +477,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs): @register_model def beit_large_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -451,7 +486,7 @@ def beit_large_patch16_384(pretrained=False, **kwargs): @register_model def beit_large_patch16_512(pretrained=False, **kwargs): model_kwargs = dict( - img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_512', pretrained=pretrained, 
**model_kwargs) return model @@ -460,7 +495,7 @@ def beit_large_patch16_512(pretrained=False, **kwargs): @register_model def beit_large_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model @@ -487,7 +522,7 @@ def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs): @register_model def beitv2_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -496,7 +531,33 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs): @register_model def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) return model + + +def eva_giant_patch14_224(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_336(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_560(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs) + return model From 3cc4d7a894fbc2477d58551fffc9f28d8434c3d3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 16:57:07 -0800 Subject: [PATCH 2/9] Fix missing register for 224 eva model --- timm/models/beit.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index c36683ef..162ba81b 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -130,18 +130,15 @@ default_cfgs = generate_default_cfgs({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, ), 'eva_giant_patch14_336.clip_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', - hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336)), 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', - hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', 
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 336, 336)), 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', - hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 560, 560)), }) @@ -537,6 +534,7 @@ def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs): return model +@register_model def eva_giant_patch14_224(pretrained=False, **kwargs): """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_kwargs = dict( From 98047ef5e35c18a0dccf16da6a29788e45d7225c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 Dec 2022 23:14:59 -0800 Subject: [PATCH 3/9] Add EVA FT results, hopefully fix BEiT test failures --- README.md | 11 ++- benchmark.py | 6 +- tests/test_models.py | 4 +- timm/models/beit.py | 192 ++++++++++++++++---------------------- timm/models/pretrained.py | 3 +- 5 files changed, 96 insertions(+), 120 deletions(-) diff --git a/README.md b/README.md index 735cb5a4..abcd01a4 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,16 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New # Dec 6, 2022 -* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain from https://github.com/baaivision/EVA +* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. + * original source: https://github.com/baaivision/EVA + * paper: https://arxiv.org/abs/2211.07636 + +| model | top1 | param_count | gmac | macts | hub | +|:-----------------------------------------|-------:|--------------:|-------:|--------:|:----------------------------------------| +| eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.8 | 1014.4 | 1906.8 | 2577.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.6 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) | # Dec 5, 2022 diff --git a/benchmark.py b/benchmark.py index 04557a7d..9adeb465 100755 --- a/benchmark.py +++ b/benchmark.py @@ -80,9 +80,11 @@ parser.add_argument('--results-file', default='', type=str, parser.add_argument('--results-format', default='csv', type=str, help='Format for results file one of (csv, json) (default: csv).') parser.add_argument('--num-warm-iter', default=10, type=int, - metavar='N', help='Number of warmup iterations (default: 10)') + help='Number of warmup iterations (default: 10)') parser.add_argument('--num-bench-iter', default=40, type=int, - metavar='N', help='Number of benchmark iterations (default: 40)') + help='Number of benchmark iterations (default: 40)') +parser.add_argument('--device', default='cuda', type=str, + help="device to run benchmark on") # common inference / train args parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', diff --git a/tests/test_models.py b/tests/test_models.py index dd1330eb..87d75cbd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 
'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -39,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', 'swin*giant*'] - NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] + NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] diff --git a/timm/models/beit.py b/timm/models/beit.py index 162ba81b..c44256a3 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -1,4 +1,4 @@ -""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) Model from official source: https://github.com/microsoft/unilm/tree/master/beit @@ -68,82 +68,6 @@ from .registry import register_model from .vision_transformer import checkpoint_filter_fn -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = generate_default_cfgs({ - 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), - 'beit_base_patch16_224.in22k_ft_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', - num_classes=21841, - ), - 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), - 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', - input_size=(3, 512, 512), crop_pct=1.0, - ), - 'beit_large_patch16_224.in22k_ft_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', - num_classes=21841, - ), - - 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( - 
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', - crop_pct=0.95, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - - 'eva_giant_patch14_224.clip_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - ), - 'eva_giant_patch14_336.clip_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 336, 336)), - 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - input_size=(3, 336, 336)), - 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - input_size=(3, 560, 560)), -}) - - def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 # cls to token & token 2 cls & cls to cls @@ -416,6 +340,82 @@ class Beit(nn.Module): return x +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_base_patch16_224.in22k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), + 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, + ), + 'beit_large_patch16_224.in22k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), + + 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', + num_classes=21841, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', + crop_pct=0.95, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', + num_classes=21841, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + + 'eva_giant_patch14_224.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, + ), + 'eva_giant_patch14_336.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'), +}) + + def _beit_checkpoint_filter_fn(state_dict, model): if 'module' in state_dict: # beit v2 didn't strip module @@ -425,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model): def _create_beit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Beit models.') + raise RuntimeError('features_only not implemented for BEiT models.') model = build_model_with_cfg( Beit, variant, pretrained, @@ -453,15 +453,6 @@ def beit_base_patch16_384(pretrained=False, **kwargs): return model -@register_model -def beit_base_patch16_224_in22k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) - model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) - return model - - @register_model def beit_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( @@ -489,15 +480,6 @@ def beit_large_patch16_512(pretrained=False, **kwargs): return model -@register_model -def beit_large_patch16_224_in22k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) - return model - - @register_model def beitv2_base_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( @@ -507,15 +489,6 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs): return model -@register_model -def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, 
**kwargs) - model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) - return model - - @register_model def beitv2_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( @@ -525,15 +498,6 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs): return model -@register_model -def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) - return model - - @register_model def eva_giant_patch14_224(pretrained=False, **kwargs): """ EVA-g model https://arxiv.org/abs/2211.07636 """ diff --git a/timm/models/pretrained.py b/timm/models/pretrained.py index 60f38fd4..2ca7ac5a 100644 --- a/timm/models/pretrained.py +++ b/timm/models/pretrained.py @@ -59,10 +59,11 @@ class PretrainedCfg: def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): filtered_cfg = {} + keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none for k, v in cfg.items(): if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: continue - if remove_null and v is None: + if remove_null and v is None and k not in keep_none: continue filtered_cfg[k] = v return filtered_cfg From 6a92587e0d95d662b3fee0796cb82b34a706c5b0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 7 Dec 2022 09:34:21 -0800 Subject: [PATCH 4/9] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index abcd01a4..994775f1 100644 --- a/README.md +++ b/README.md @@ -388,6 +388,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * TinyNet - https://arxiv.org/abs/2010.14819 +* EVA - https://arxiv.org/abs/2211.07636 * GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959 * GhostNet - https://arxiv.org/abs/1911.11907 * gMLP - https://arxiv.org/abs/2105.08050 From 7c4ed4d5a43f46084cc9b6f20a5edb8839bbeb14 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 8 Dec 2022 16:20:49 -0800 Subject: [PATCH 5/9] Add EVA-large models --- README.md | 11 +++++++++ timm/models/vision_transformer.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/README.md b/README.md index 994775f1..331ea7a8 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,17 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Dec 8, 2022 +* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) + * original source: https://github.com/baaivision/EVA + +| model | top1 | param_count | gmac | macts | hub | +|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------| +| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 
63.5 | [link](https://huggingface.co/BAAI/EVA) | + # Dec 6, 2022 * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. * original source: https://github.com/baaivision/EVA diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4effbed6..820dc656 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -933,6 +933,25 @@ default_cfgs = generate_default_cfgs({ 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), }) @@ -1354,3 +1373,21 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def eva_large_patch14_196(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_large_patch14_336(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + return model From 3d6bc42aa15b67883e3bf0f92df92fc7b74030b1 Mon Sep 17 00:00:00 2001 From: Lorenzo Baraldi Date: Fri, 9 Dec 2022 12:03:23 +0100 Subject: [PATCH 6/9] Put validation loss under amp_autocast Secured the loss evaluation under the amp, avoiding function to operate on float16 --- train.py | 16 ++++++++-------- validate.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index d40ff04b..b85eb6b0 100755 --- a/train.py +++ b/train.py @@ -970,16 +970,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if 
reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 6b8222b9..872f27b0 100755 --- a/validate.py +++ b/validate.py @@ -294,9 +294,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output) From 9e47d8ad5942a3e60ab12978c1aca5068c201929 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:13:37 -0800 Subject: [PATCH 7/9] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 331ea7a8..a63c2b14 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,13 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +# Survey: Feedback Appreciated + +For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. + +If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: +[**hf.co/oss-survey**](https://hf.co/oss-survey) + # Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) * original source: https://github.com/baaivision/EVA From 1733177c75a23d3f8b34ffe4c8c9316440bef323 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:14:35 -0800 Subject: [PATCH 8/9] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a63c2b14..130b604c 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,14 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -# Survey: Feedback Appreciated +### Survey: Feedback Appreciated For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. 
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: [**hf.co/oss-survey**](https://hf.co/oss-survey) -# Dec 8, 2022 +### Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) * original source: https://github.com/baaivision/EVA @@ -39,7 +39,7 @@ If you have a couple of minutes and want to participate in shaping the future of | eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | | eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | -# Dec 6, 2022 +### Dec 6, 2022 * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. * original source: https://github.com/baaivision/EVA * paper: https://arxiv.org/abs/2211.07636 @@ -51,7 +51,7 @@ If you have a couple of minutes and want to participate in shaping the future of | eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | | eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) | -# Dec 5, 2022 +### Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm` * vision_transformer, maxvit, convnext are the first three model impl w/ support From 0fe90449e5e07e91a78ea847b76eda6010b55283 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 9 Dec 2022 11:21:46 -0800 Subject: [PATCH 9/9] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 130b604c..bb6485c0 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -### Survey: Feedback Appreciated +### 🤗 Survey: Feedback Appreciated 🤗 For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: -[**hf.co/oss-survey**](https://hf.co/oss-survey) +[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏 ### Dec 8, 2022 * Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some)
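
To close out the series, a minimal end-to-end inference sketch for the EVA weights listed in the tables above. It assumes the `timm` 0.8.x data helpers (`resolve_data_config` / `create_transform`) resolve the per-weight preprocessing (mean/std, input size, and crop settings such as `crop_pct=1.0` / `crop_mode='squash'`) from the pretrained cfgs added in these patches; the image path is a placeholder.

```python
import torch
import timm
from PIL import Image
from timm.data import resolve_data_config, create_transform

# Pick any of the tagged EVA weights from the tables, e.g. the 336px EVA-large.
model = timm.create_model('eva_large_patch14_336.in22k_ft_in22k_in1k', pretrained=True)
model.eval()

# Build the eval transform from the weight's pretrained cfg (input size, mean/std, crop).
cfg = resolve_data_config({}, model=model)
transform = create_transform(**cfg)

img = Image.open('example.jpg').convert('RGB')  # placeholder image path
with torch.no_grad():
    logits = model(transform(img).unsqueeze(0))
top5_prob, top5_idx = logits.softmax(dim=-1).topk(5)
```

The `squash` crop mode referenced in these cfgs resizes the image to the target resolution without centre-cropping, which appears to match how the 336px and 560px EVA weights were evaluated for the accuracies reported above.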