From e069249a2dab0056a7687d5e48043757c46e7525 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 16 Sep 2022 21:39:05 -0700
Subject: [PATCH] Add hf hub entries for laion2b clip models, add
 huggingface_hub dependency, update some setup/reqs, torch >= 1.7

---
 requirements.txt                  |  5 +++--
 setup.py                          |  9 ++++++---
 tests/test_models.py              |  2 +-
 timm/data/constants.py            |  2 ++
 timm/models/vision_transformer.py | 32 ++++++++++++++++++--------------
 5 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2d29a27c..5846bb36 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
-torch>=1.4.0
-torchvision>=0.5.0
+torch>=1.7
+torchvision
 pyyaml
+huggingface_hub
diff --git a/setup.py b/setup.py
index 882ed467..59b4ed4c 100644
--- a/setup.py
+++ b/setup.py
@@ -25,13 +25,15 @@ setup(
         #   3 - Alpha
         #   4 - Beta
         #   5 - Production/Stable
-        'Development Status :: 3 - Alpha',
+        'Development Status :: 4 - Beta',
         'Intended Audience :: Education',
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development',
@@ -40,9 +42,10 @@ setup(
     ],
 
     # Note that this is a string of words separated by whitespace, not a list.
-    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
+    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
     packages=find_packages(exclude=['convert', 'tests', 'results']),
     include_package_data=True,
-    install_requires=['torch >= 1.4', 'torchvision'],
+    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
     python_requires='>=3.6',
 )
+
diff --git a/tests/test_models.py b/tests/test_models.py
index f0b5a820..d007d65a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_g*', 'swin*huge*',
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
         'swin*giant*']
     NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
 else:
diff --git a/timm/data/constants.py b/timm/data/constants.py
index d6d4a01b..e4d8bb7e 100644
--- a/timm/data/constants.py
+++ b/timm/data/constants.py
@@ -5,3 +5,5 @@ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
 IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
 IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
 IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
+OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ac8e820c..474ea23c 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -30,7 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
+    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
@@ -106,7 +107,7 @@ default_cfgs = {
     'vit_large_patch14_224': _cfg(url=''),
     'vit_huge_patch14_224': _cfg(url=''),
     'vit_giant_patch14_224': _cfg(url=''),
-    'vit_gee_patch14_224': _cfg(url=''),
+    'vit_gigantic_patch14_224': _cfg(url=''),
 
 
     # patch models, imagenet21k (weights from official Google JAX impl)
@@ -179,17 +180,21 @@ default_cfgs = {
 
     'vit_base_patch16_18x2_224': _cfg(url=''),
 
     'vit_base_patch32_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=512),
+        hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_large_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=768),
+        hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
     'vit_huge_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=1024),
+        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
     'vit_giant_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=1024),
+        hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
 }
 
@@ -960,12 +965,11 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
 
 
 @register_model
-def vit_gee_patch14_224(pretrained=False, **kwargs):
-    """ ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
-    As per https://twitter.com/wightmanr/status/1570549064667889666
+def vit_gigantic_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
     model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
     return model
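
Usage sketch (not part of the patch): a minimal example of how the new laion2b
hub entries are intended to be consumed, assuming huggingface_hub is installed
and the open_clip checkpoints referenced above load through timm's normal
pretrained=True hub path. The model name and the 512-dim output below come
straight from the cfgs added in this patch.

    import torch
    import timm
    from timm.data import resolve_data_config

    # CLIP ViT-B/32 image tower; the head is the 512-dim projection
    # (num_classes=512 in the new default cfg), so the forward pass returns
    # the CLIP image embedding rather than ImageNet logits.
    model = timm.create_model('vit_base_patch32_224_clip_laion2b', pretrained=True)
    model.eval()

    # The resolved data config picks up OPENAI_CLIP_MEAN / OPENAI_CLIP_STD
    # from the pretrained cfg added in this patch.
    config = resolve_data_config({}, model=model)
    print(config['mean'], config['std'])

    with torch.no_grad():
        embedding = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 512)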