From 9709dbaaa95ee603841fcc055a96327f8edf4320 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Sep 2022 17:25:59 -0700 Subject: [PATCH] Adding support for fine-tune CLIP LAION-2B image tower weights for B/32, L/14, H/14 and g/14. Still WIP --- tests/test_models.py | 2 +- timm/models/helpers.py | 8 +- timm/models/hub.py | 8 +- timm/models/layers/patch_embed.py | 13 +- timm/models/vision_transformer.py | 198 ++++++++++++++++++++--- timm/models/vision_transformer_hybrid.py | 13 +- 6 files changed, 214 insertions(+), 28 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index d007d65a..f0b5a820 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', - '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_g*', 'swin*huge*', 'swin*giant*'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] else: diff --git a/timm/models/helpers.py b/timm/models/helpers.py index fda84171..b84a4523 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg): # hf-hub available as alternate weight source in default_cfg load_from = 'hf-hub' pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg: + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] return load_from, pretrained_loc @@ -246,7 +249,10 @@ def load_pretrained( pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - state_dict = load_state_dict_from_hf(pretrained_loc) + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) else: _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") return diff --git a/timm/models/hub.py b/timm/models/hub.py index c3d3d15e..265259e5 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -55,7 +55,7 @@ def download_cached_file(url, check_hash=True, progress=False): def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: - # if no HF Hub module installed and it is necessary to continue, raise error + # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') return _has_hf_hub @@ -78,7 +78,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]): def _download_from_hf(model_id: str, filename: str): hf_model_id, hf_revision = hf_split(model_id) - return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf')) + return hf_hub_download(hf_model_id, filename, revision=hf_revision) def load_model_config_from_hf(model_id: str): @@ -91,9 +91,9 @@ def load_model_config_from_hf(model_id: str): return pretrained_cfg, model_name -def load_state_dict_from_hf(model_id: str): +def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + cached_file = _download_from_hf(model_id, filename) state_dict = torch.load(cached_file, map_location='cpu') return state_dict diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 6a7facef..be8740ce 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,16 @@ from .trace_utils import _assert class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -25,7 +34,7 @@ class PatchEmbed(nn.Module): self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 296e575e..b78b9197 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -25,6 +25,7 @@ from functools import partial from collections import OrderedDict from typing import Optional +import huggingface_hub.file_download import torch import torch.nn as nn import torch.nn.functional as F @@ -106,7 +107,7 @@ default_cfgs = { 'vit_large_patch14_224': _cfg(url=''), 'vit_huge_patch14_224': _cfg(url=''), 'vit_giant_patch14_224': _cfg(url=''), - 'vit_gigantic_patch14_224': _cfg(url=''), + 'vit_gee_patch14_224': _cfg(url=''), # patch models, imagenet21k (weights from official Google JAX impl) @@ -177,6 +178,20 @@ default_cfgs = { 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + 'vit_base_patch32_224_clip_laion2b': _cfg( + hf_hub_id='', + num_classes=512), + 'vit_large_patch14_224_clip_laion2b': _cfg( + hf_hub_id='', + num_classes=768), + 'vit_huge_patch14_224_clip_laion2b': _cfg( + hf_hub_id='', + num_classes=1024), + 'vit_giant_patch14_224_clip_laion2b': _cfg( + hf_hub_id='', + num_classes=1024), + } @@ -221,8 +236,18 @@ class LayerScale(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) @@ -244,8 +269,18 @@ class Block(nn.Module): class ResPostBlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.init_values = init_values @@ -274,8 +309,19 @@ class ResPostBlock(nn.Module): class ParallelBlock(nn.Module): def __init__( - self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + num_parallel=2, + mlp_ratio=4., + qkv_bias=False, + init_values=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() @@ -320,10 +366,31 @@ class VisionTransformer(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=Block, + ): """ Args: img_size (int, tuple): input image size @@ -362,19 +429,34 @@ class VisionTransformer(nn.Module): self.grad_checkpointing = False self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer + ) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() @@ -445,6 +527,7 @@ class VisionTransformer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) x = self._pos_embed(x) + x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -623,6 +706,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): return posemb +def _convert_openai_clip(state_dict, model): + out_dict = {} + swaps = [ + ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), + ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), + ] + for k, v in state_dict.items(): + if not k.startswith('visual.'): + continue + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + if v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + model.pos_embed, + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + out_dict[k] = v + return out_dict + + def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re @@ -631,6 +748,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): # For deit models state_dict = state_dict['model'] + if 'visual.class_embedding' in state_dict: + return _convert_openai_clip(state_dict, model) + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification @@ -833,7 +953,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs): @register_model def vit_giant_patch14_224(pretrained=False, **kwargs): - """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) @@ -841,11 +961,12 @@ def vit_giant_patch14_224(pretrained=False, **kwargs): @register_model -def vit_gigantic_patch14_224(pretrained=False, **kwargs): - """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 +def vit_gee_patch14_224(pretrained=False, **kwargs): + """ ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + As per https://twitter.com/wightmanr/status/1570549064667889666 """ model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs) return model @@ -1085,3 +1206,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-B/32 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 24ff2096..156894ac 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -101,7 +101,16 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. """ - def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=768, + bias=True, + ): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) @@ -130,7 +139,7 @@ class HybridEmbed(nn.Module): assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) def forward(self, x): x = self.backbone(x)