Adding support for fine-tune CLIP LAION-2B image tower weights for B/32, L/14, H/14 and g/14. Still WIP

pull/1467/head
Ross Wightman 2 years ago
parent a520da9b49
commit 9709dbaaa9

@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FILTERS = [ EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_g*', 'swin*huge*',
'swin*giant*'] 'swin*giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
else: else:

@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub available as alternate weight source in default_cfg # hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub' load_from = 'hf-hub'
pretrained_loc = hf_hub_id pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc return load_from, pretrained_loc
@ -246,6 +249,9 @@ def load_pretrained(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
elif load_from == 'hf-hub': elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc) state_dict = load_state_dict_from_hf(pretrained_loc)
else: else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")

@ -55,7 +55,7 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub(necessary=False): def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary: if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error # if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError( raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub return _has_hf_hub
@ -78,7 +78,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
def _download_from_hf(model_id: str, filename: str): def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id) hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf')) return hf_hub_download(hf_model_id, filename, revision=hf_revision)
def load_model_config_from_hf(model_id: str): def load_model_config_from_hf(model_id: str):
@ -91,9 +91,9 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name return pretrained_cfg, model_name
def load_state_dict_from_hf(model_id: str): def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True) assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin') cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu') state_dict = torch.load(cached_file, map_location='cpu')
return state_dict return state_dict

@ -15,7 +15,16 @@ from .trace_utils import _assert
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """ 2D Image to Patch Embedding
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@ -25,7 +34,7 @@ class PatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x): def forward(self, x):

@ -25,6 +25,7 @@ from functools import partial
from collections import OrderedDict from collections import OrderedDict
from typing import Optional from typing import Optional
import huggingface_hub.file_download
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -106,7 +107,7 @@ default_cfgs = {
'vit_large_patch14_224': _cfg(url=''), 'vit_large_patch14_224': _cfg(url=''),
'vit_huge_patch14_224': _cfg(url=''), 'vit_huge_patch14_224': _cfg(url=''),
'vit_giant_patch14_224': _cfg(url=''), 'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''), 'vit_gee_patch14_224': _cfg(url=''),
# patch models, imagenet21k (weights from official Google JAX impl) # patch models, imagenet21k (weights from official Google JAX impl)
@ -177,6 +178,20 @@ default_cfgs = {
'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''),
'vit_base_patch32_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=512),
'vit_large_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=768),
'vit_huge_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=1024),
'vit_giant_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=1024),
} }
@ -221,8 +236,18 @@ class LayerScale(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, self,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
@ -244,8 +269,18 @@ class Block(nn.Module):
class ResPostBlock(nn.Module): class ResPostBlock(nn.Module):
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, self,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.init_values = init_values self.init_values = init_values
@ -274,8 +309,19 @@ class ResPostBlock(nn.Module):
class ParallelBlock(nn.Module): class ParallelBlock(nn.Module):
def __init__( def __init__(
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, self,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
num_parallel=2,
mlp_ratio=4.,
qkv_bias=False,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.num_parallel = num_parallel self.num_parallel = num_parallel
self.attns = nn.ModuleList() self.attns = nn.ModuleList()
@ -320,10 +366,31 @@ class VisionTransformer(nn.Module):
""" """
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', self,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, img_size=224,
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., patch_size=16,
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -362,19 +429,34 @@ class VisionTransformer(nn.Module):
self.grad_checkpointing = False self.grad_checkpointing = False
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
)
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
block_fn( block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, dim=embed_dim,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer
)
for i in range(depth)]) for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@ -445,6 +527,7 @@ class VisionTransformer(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
x = self._pos_embed(x) x = self._pos_embed(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
@ -623,6 +706,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
return posemb return posemb
def _convert_openai_clip(state_dict, model):
out_dict = {}
swaps = [
('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
]
for k, v in state_dict.items():
if not k.startswith('visual.'):
continue
for sp in swaps:
k = k.replace(sp[0], sp[1])
if k == 'proj':
k = 'head.weight'
v = v.transpose(0, 1)
out_dict['head.bias'] = torch.zeros(v.shape[0])
elif k == 'class_embedding':
k = 'cls_token'
v = v.unsqueeze(0).unsqueeze(1)
elif k == 'pos_embed':
v = v.unsqueeze(0)
if v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v,
model.pos_embed,
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
import re import re
@ -631,6 +748,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
# For deit models # For deit models
state_dict = state_dict['model'] state_dict = state_dict['model']
if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model)
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4: if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification # For old models that I trained prior to conv based patchification
@ -833,7 +953,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_giant_patch14_224(pretrained=False, **kwargs): def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
""" """
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
@ -841,11 +961,12 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_gigantic_patch14_224(pretrained=False, **kwargs): def vit_gee_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
As per https://twitter.com/wightmanr/status/1570549064667889666
""" """
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
return model return model
@ -1085,3 +1206,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model

@ -101,7 +101,16 @@ class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding """ CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim. Extract feature map from CNN, flatten, project to embedding dim.
""" """
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
bias=True,
):
super().__init__() super().__init__()
assert isinstance(backbone, nn.Module) assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
@ -130,7 +139,7 @@ class HybridEmbed(nn.Module):
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x): def forward(self, x):
x = self.backbone(x) x = self.backbone(x)

Loading…
Cancel
Save