Merge pull request #1467 from rwightman/clip_laion2b

Adding support for fine-tune CLIP LAION-2B image tower weights for B/32, L/14, H/14, and g/14.
pull/1476/head
Ross Wightman 2 years ago committed by GitHub
commit d199f6651d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
torch>=1.4.0 torch>=1.7
torchvision>=0.5.0 torchvision
pyyaml pyyaml
huggingface_hub

@ -25,13 +25,15 @@ setup(
# 3 - Alpha # 3 - Alpha
# 4 - Beta # 4 - Beta
# 5 - Production/Stable # 5 - Production/Stable
'Development Status :: 3 - Alpha', 'Development Status :: 4 - Beta',
'Intended Audience :: Education', 'Intended Audience :: Education',
'Intended Audience :: Science/Research', 'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License', 'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development', 'Topic :: Software Development',
@ -40,9 +42,10 @@ setup(
], ],
# Note that this is a string of words separated by whitespace, not a list. # Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']), packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True, include_package_data=True,
install_requires=['torch >= 1.4', 'torchvision'], install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
python_requires='>=3.6', python_requires='>=3.6',
) )

@ -5,3 +5,5 @@ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub available as alternate weight source in default_cfg # hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub' load_from = 'hf-hub'
pretrained_loc = hf_hub_id pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc return load_from, pretrained_loc
@ -246,6 +249,9 @@ def load_pretrained(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
elif load_from == 'hf-hub': elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc) state_dict = load_state_dict_from_hf(pretrained_loc)
else: else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")

@ -13,6 +13,7 @@ except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
from timm import __version__ from timm import __version__
try: try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub(necessary=False): def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary: if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error # if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError( raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub return _has_hf_hub
@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
def _download_from_hf(model_id: str, filename: str): def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id) hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf')) return hf_hub_download(hf_model_id, filename, revision=hf_revision)
def load_model_config_from_hf(model_id: str): def load_model_config_from_hf(model_id: str):
@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name return pretrained_cfg, model_name
def load_state_dict_from_hf(model_id: str): def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True) assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin') cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu') state_dict = torch.load(cached_file, map_location='cpu')
return state_dict return state_dict

@ -15,7 +15,16 @@ from .trace_utils import _assert
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """ 2D Image to Patch Embedding
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@ -25,7 +34,7 @@ class PatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x): def forward(self, x):

@ -30,7 +30,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model from .registry import register_model
@ -177,6 +178,24 @@ default_cfgs = {
'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''),
'vit_base_patch32_224_clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_224_clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
'vit_huge_patch14_224_clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
'vit_giant_patch14_224_clip_laion2b': _cfg(
hf_hub_id='CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
} }
@ -221,8 +240,18 @@ class LayerScale(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, self,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
@ -244,8 +273,18 @@ class Block(nn.Module):
class ResPostBlock(nn.Module): class ResPostBlock(nn.Module):
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, self,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.init_values = init_values self.init_values = init_values
@ -274,8 +313,19 @@ class ResPostBlock(nn.Module):
class ParallelBlock(nn.Module): class ParallelBlock(nn.Module):
def __init__( def __init__(
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, self,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): dim,
num_heads,
num_parallel=2,
mlp_ratio=4.,
qkv_bias=False,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__() super().__init__()
self.num_parallel = num_parallel self.num_parallel = num_parallel
self.attns = nn.ModuleList() self.attns = nn.ModuleList()
@ -320,10 +370,31 @@ class VisionTransformer(nn.Module):
""" """
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', self,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, img_size=224,
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., patch_size=16,
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -362,19 +433,34 @@ class VisionTransformer(nn.Module):
self.grad_checkpointing = False self.grad_checkpointing = False
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
)
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
block_fn( block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, dim=embed_dim,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer
)
for i in range(depth)]) for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@ -445,6 +531,7 @@ class VisionTransformer(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
x = self._pos_embed(x) x = self._pos_embed(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
@ -623,6 +710,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
return posemb return posemb
def _convert_openai_clip(state_dict, model):
out_dict = {}
swaps = [
('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
]
for k, v in state_dict.items():
if not k.startswith('visual.'):
continue
for sp in swaps:
k = k.replace(sp[0], sp[1])
if k == 'proj':
k = 'head.weight'
v = v.transpose(0, 1)
out_dict['head.bias'] = torch.zeros(v.shape[0])
elif k == 'class_embedding':
k = 'cls_token'
v = v.unsqueeze(0).unsqueeze(1)
elif k == 'pos_embed':
v = v.unsqueeze(0)
if v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v,
model.pos_embed,
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
import re import re
@ -631,6 +752,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
# For deit models # For deit models
state_dict = state_dict['model'] state_dict = state_dict['model']
if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model)
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4: if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification # For old models that I trained prior to conv based patchification
@ -833,7 +957,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_giant_patch14_224(pretrained=False, **kwargs): def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
""" """
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
@ -842,7 +966,7 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_gigantic_patch14_224(pretrained=False, **kwargs): def vit_gigantic_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
""" """
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
@ -1085,3 +1209,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model

@ -101,7 +101,16 @@ class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding """ CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim. Extract feature map from CNN, flatten, project to embedding dim.
""" """
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
bias=True,
):
super().__init__() super().__init__()
assert isinstance(backbone, nn.Module) assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
@ -130,7 +139,7 @@ class HybridEmbed(nn.Module):
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x): def forward(self, x):
x = self.backbone(x) x = self.backbone(x)

@ -1,7 +1,7 @@
""" Optimizer Factory w/ Custom Weight Decay """ Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman Hacked together by / Copyright 2021 Ross Wightman
""" """
import json import logging
from itertools import islice from itertools import islice
from typing import Optional, Callable, Tuple from typing import Optional, Callable, Tuple
@ -31,6 +31,8 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
_logger = logging.getLogger(__name__)
def param_groups_weight_decay( def param_groups_weight_decay(
model: nn.Module, model: nn.Module,
@ -92,6 +94,7 @@ def param_groups_layer_decay(
no_weight_decay_list: Tuple[str] = (), no_weight_decay_list: Tuple[str] = (),
layer_decay: float = .75, layer_decay: float = .75,
end_layer_decay: Optional[float] = None, end_layer_decay: Optional[float] = None,
verbose: bool = False,
): ):
""" """
Parameter groups for layer-wise lr decay & weight decay Parameter groups for layer-wise lr decay & weight decay
@ -142,8 +145,9 @@ def param_groups_layer_decay(
param_group_names[group_name]["param_names"].append(name) param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param) param_groups[group_name]["params"].append(param)
# FIXME temporary output to debug new feature if verbose:
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) import json
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values()) return list(param_groups.values())

Loading…
Cancel
Save