Add 'gigantic' vit clip variant for feature extraction and future fine-tuning

pull/1583/head^2
Ross Wightman 1 year ago
parent 3aa31f537d
commit 64667bfa0e

@ -1029,6 +1029,10 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/',
@ -1498,6 +1502,17 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
return model
@register_model
def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
# Experimental models below
@register_model

Loading…
Cancel
Save