Add EVA FT results, hopefully fix BEiT test failures

pull/1578/head^2
Ross Wightman 2 years ago committed by Ross Wightman
parent 3cc4d7a894
commit 98047ef5e3

@ -22,7 +22,16 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New ## What's New
# Dec 6, 2022 # Dec 6, 2022
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain from https://github.com/baaivision/EVA * Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
* original source: https://github.com/baaivision/EVA
* paper: https://arxiv.org/abs/2211.07636
| model | top1 | param_count | gmac | macts | hub |
|:-----------------------------------------|-------:|--------------:|-------:|--------:|:----------------------------------------|
| eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.8 | 1014.4 | 1906.8 | 2577.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.6 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) |
# Dec 5, 2022 # Dec 5, 2022

@ -80,9 +80,11 @@ parser.add_argument('--results-file', default='', type=str,
parser.add_argument('--results-format', default='csv', type=str, parser.add_argument('--results-format', default='csv', type=str,
help='Format for results file one of (csv, json) (default: csv).') help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--num-warm-iter', default=10, type=int, parser.add_argument('--num-warm-iter', default=10, type=int,
metavar='N', help='Number of warmup iterations (default: 10)') help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int, parser.add_argument('--num-bench-iter', default=40, type=int,
metavar='N', help='Number of benchmark iterations (default: 40)') help='Number of benchmark iterations (default: 40)')
parser.add_argument('--device', default='cuda', type=str,
help="device to run benchmark on")
# common inference / train args # common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',

@ -27,7 +27,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
] ]
NUM_NON_STD = len(NON_STD_FILTERS) NUM_NON_STD = len(NON_STD_FILTERS)
@ -39,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*'] 'swin*giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else: else:
EXCLUDE_FILTERS = [] EXCLUDE_FILTERS = []
NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] NON_STD_EXCLUDE_FILTERS = ['vit_gi*']

@ -1,4 +1,4 @@
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) """ BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
Model from official source: https://github.com/microsoft/unilm/tree/master/beit Model from official source: https://github.com/microsoft/unilm/tree/master/beit
@ -68,82 +68,6 @@ from .registry import register_model
from .vision_transformer import checkpoint_filter_fn from .vision_transformer import checkpoint_filter_fn
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
input_size=(3, 512, 512), crop_pct=1.0,
),
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336)),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336)),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560)),
})
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# cls to token & token 2 cls & cls to cls # cls to token & token 2 cls & cls to cls
@ -416,6 +340,82 @@ class Beit(nn.Module):
return x return x
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
input_size=(3, 512, 512), crop_pct=1.0,
),
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})
def _beit_checkpoint_filter_fn(state_dict, model): def _beit_checkpoint_filter_fn(state_dict, model):
if 'module' in state_dict: if 'module' in state_dict:
# beit v2 didn't strip module # beit v2 didn't strip module
@ -425,7 +425,7 @@ def _beit_checkpoint_filter_fn(state_dict, model):
def _create_beit(variant, pretrained=False, **kwargs): def _create_beit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Beit models.') raise RuntimeError('features_only not implemented for BEiT models.')
model = build_model_with_cfg( model = build_model_with_cfg(
Beit, variant, pretrained, Beit, variant, pretrained,
@ -453,15 +453,6 @@ def beit_base_patch16_384(pretrained=False, **kwargs):
return model return model
@register_model
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def beit_large_patch16_224(pretrained=False, **kwargs): def beit_large_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict( model_kwargs = dict(
@ -489,15 +480,6 @@ def beit_large_patch16_512(pretrained=False, **kwargs):
return model return model
@register_model
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs): def beitv2_base_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict( model_kwargs = dict(
@ -507,15 +489,6 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):
return model return model
@register_model
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs): def beitv2_large_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict( model_kwargs = dict(
@ -525,15 +498,6 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
return model return model
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def eva_giant_patch14_224(pretrained=False, **kwargs): def eva_giant_patch14_224(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """ """ EVA-g model https://arxiv.org/abs/2211.07636 """

@ -59,10 +59,11 @@ class PretrainedCfg:
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
filtered_cfg = {} filtered_cfg = {}
keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
for k, v in cfg.items(): for k, v in cfg.items():
if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
continue continue
if remove_null and v is None: if remove_null and v is None and k not in keep_none:
continue continue
filtered_cfg[k] = v filtered_cfg[k] = v
return filtered_cfg return filtered_cfg

Loading…
Cancel
Save