diff --git a/README.md b/README.md index 798f94f3..bb6485c0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,37 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -# Dec 5, 2022 +### 🤗 Survey: Feedback Appreciated 🤗 + +For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. + +If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: +[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏 + +### Dec 8, 2022 +* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some) + * original source: https://github.com/baaivision/EVA + +| model | top1 | param_count | gmac | macts | hub | +|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------| +| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | +| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) | + +### Dec 6, 2022 +* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`. + * original source: https://github.com/baaivision/EVA + * paper: https://arxiv.org/abs/2211.07636 + +| model | top1 | param_count | gmac | macts | hub | +|:-----------------------------------------|-------:|--------------:|-------:|--------:|:----------------------------------------| +| eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.8 | 1014.4 | 1906.8 | 2577.2 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.6 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) | +| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) | + +### Dec 5, 2022 * Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm` * vision_transformer, maxvit, convnext are the first three model impl w/ support @@ -376,6 +406,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * TinyNet - https://arxiv.org/abs/2010.14819 +* EVA - https://arxiv.org/abs/2211.07636 * GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959 * GhostNet - https://arxiv.org/abs/1911.11907 * gMLP - https://arxiv.org/abs/2105.08050 diff --git a/benchmark.py b/benchmark.py index 95e2cb5a..58435ff8 100755 --- a/benchmark.py +++ b/benchmark.py @@ -81,9 +81,11 @@ parser.add_argument('--results-file', default='', type=str, parser.add_argument('--results-format', default='csv', type=str, help='Format for results file one of (csv, json) (default: csv).') parser.add_argument('--num-warm-iter', default=10, type=int, - metavar='N', help='Number of warmup iterations (default: 10)') + help='Number of warmup iterations (default: 10)') parser.add_argument('--num-bench-iter', default=40, type=int, - metavar='N', help='Number of benchmark iterations (default: 40)') + help='Number of benchmark iterations (default: 40)') +parser.add_argument('--device', default='cuda', type=str, + help="device to run benchmark on") # common inference / train args parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', diff --git a/tests/test_models.py b/tests/test_models.py index d6c0052f..2392a190 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -39,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', 'swin*giant*'] - NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] + NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index c422dab7..b5ecbc50 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -62,10 +62,11 @@ class PretrainedCfg: def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): filtered_cfg = {} + keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none for k, v in cfg.items(): if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: continue - if remove_null and v is None: + if remove_null and v is None and k not in keep_none: continue filtered_cfg[k] = v return filtered_cfg diff --git a/timm/models/beit.py b/timm/models/beit.py index 7c4dd14d..de71f441 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -1,8 +1,6 @@ -""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) Model from official source: https://github.com/microsoft/unilm/tree/master/beit -and -https://github.com/microsoft/unilm/tree/master/beit2 @inproceedings{beit, title={{BEiT}: {BERT} Pre-Training of Image Transformers}, @@ -12,6 +10,8 @@ year={2022}, url={https://openreview.net/forum?id=p-BhZSz59o4} } +BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2 + @article{beitv2, title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers}, author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei}, @@ -21,6 +21,17 @@ archivePrefix={arXiv}, primaryClass={cs.CV} } +EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 + +@article{EVA, + title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, + author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, + Tiejun and Wang, Xinlong and Cao, Yue}, + journal={arXiv preprint arXiv:2211.07636}, + year={2022} +} + + At this point only the 1k fine-tuned classification weights and model configs have been added, see original source above for pre-training models and procedure. @@ -37,6 +48,9 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # https://github.com/facebookresearch/deit/ # https://github.com/facebookresearch/dino # --------------------------------------------------------' + +# EVA models Copyright (c) 2022 BAAI-Vision + import math from functools import partial from typing import Optional, Tuple @@ -49,69 +63,12 @@ from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ from ._builder import build_model_with_cfg +from ._pretrained import generate_default_cfgs from ._registry import register_model from .vision_transformer import checkpoint_filter_fn __all__ = ['Beit'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'beit_base_patch16_224': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_base_patch16_384': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), - 'beit_base_patch16_224_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', - num_classes=21841, - ), - 'beit_large_patch16_224': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), - 'beit_large_patch16_384': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), - 'beit_large_patch16_512': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', - input_size=(3, 512, 512), crop_pct=1.0, - ), - 'beit_large_patch16_224_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', - num_classes=21841, - ), - - 'beitv2_base_patch16_224': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_base_patch16_224_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_large_patch16_224': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', - crop_pct=0.95, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), - 'beitv2_large_patch16_224_in22k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD - ), -} - def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 @@ -385,6 +342,82 @@ class Beit(nn.Module): return x +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_base_patch16_224.in22k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), + 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, + ), + 'beit_large_patch16_224.in22k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), + + 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', + num_classes=21841, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', + crop_pct=0.95, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( + url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', + num_classes=21841, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + ), + + 'eva_giant_patch14_224.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, + ), + 'eva_giant_patch14_336.clip_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'), +}) + + def _beit_checkpoint_filter_fn(state_dict, model): if 'module' in state_dict: # beit v2 didn't strip module @@ -394,7 +427,7 @@ def _beit_checkpoint_filter_fn(state_dict, model): def _create_beit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Beit models.') + raise RuntimeError('features_only not implemented for BEiT models.') model = build_model_with_cfg( Beit, variant, pretrained, @@ -416,25 +449,16 @@ def beit_base_patch16_224(pretrained=False, **kwargs): @register_model def beit_base_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model -@register_model -def beit_base_patch16_224_in22k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) - model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) - return model - - @register_model def beit_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -443,7 +467,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs): @register_model def beit_large_patch16_384(pretrained=False, **kwargs): model_kwargs = dict( - img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -452,52 +476,52 @@ def beit_large_patch16_384(pretrained=False, **kwargs): @register_model def beit_large_patch16_512(pretrained=False, **kwargs): model_kwargs = dict( - img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs) return model @register_model -def beit_large_patch16_224_in22k(pretrained=False, **kwargs): +def beitv2_base_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def beitv2_base_patch16_224(pretrained=False, **kwargs): +def beitv2_large_patch16_224(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs): +def eva_giant_patch14_224(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs) return model @register_model -def beitv2_large_patch16_224(pretrained=False, **kwargs): +def eva_giant_patch14_336(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs) return model @register_model -def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs): +def eva_giant_patch14_560(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) - model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3c2ebc29..5b93628f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -938,6 +938,25 @@ default_cfgs = generate_default_cfgs({ 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), }) @@ -1359,3 +1378,21 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def eva_large_patch14_196(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_large_patch14_336(pretrained=False, **kwargs): + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + return model diff --git a/train.py b/train.py index 1276840d..e51d7c90 100755 --- a/train.py +++ b/train.py @@ -969,16 +969,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 3bbf07cf..4669fbac 100755 --- a/validate.py +++ b/validate.py @@ -296,9 +296,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output)