diff --git a/requirements.txt b/requirements.txt index 2d29a27c..5846bb36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -torch>=1.4.0 -torchvision>=0.5.0 +torch>=1.7 +torchvision pyyaml +huggingface_hub diff --git a/setup.py b/setup.py index 882ed467..59b4ed4c 100644 --- a/setup.py +++ b/setup.py @@ -25,13 +25,15 @@ setup( # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 3 - Alpha', + 'Development Status :: 4 - Beta', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development', @@ -40,9 +42,10 @@ setup( ], # Note that this is a string of words separated by whitespace, not a list. - keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', + keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit', packages=find_packages(exclude=['convert', 'tests', 'results']), include_package_data=True, - install_requires=['torch >= 1.4', 'torchvision'], + install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'], python_requires='>=3.6', ) + diff --git a/timm/data/constants.py b/timm/data/constants.py index d6d4a01b..e4d8bb7e 100644 --- a/timm/data/constants.py +++ b/timm/data/constants.py @@ -5,3 +5,5 @@ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index fda84171..b84a4523 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg): # hf-hub available as alternate weight source in default_cfg load_from = 'hf-hub' pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg: + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] return load_from, pretrained_loc @@ -246,7 +249,10 @@ def load_pretrained( pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') - state_dict = load_state_dict_from_hf(pretrained_loc) + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) else: _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") return diff --git a/timm/models/hub.py b/timm/models/hub.py index c3d3d15e..8b63ff7e 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -13,6 +13,7 @@ except ImportError: from torch.hub import _get_torch_home as get_dir from timm import __version__ + try: from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) @@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False): def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: - # if no HF Hub module installed and it is necessary to continue, raise error + # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') return _has_hf_hub @@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]): def _download_from_hf(model_id: str, filename: str): hf_model_id, hf_revision = hf_split(model_id) - return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf')) + return hf_hub_download(hf_model_id, filename, revision=hf_revision) def load_model_config_from_hf(model_id: str): @@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str): return pretrained_cfg, model_name -def load_state_dict_from_hf(model_id: str): +def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + cached_file = _download_from_hf(model_id, filename) state_dict = torch.load(cached_file, map_location='cpu') return state_dict diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 6a7facef..be8740ce 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,16 @@ from .trace_utils import _assert class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -25,7 +34,7 @@ class PatchEmbed(nn.Module): self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 296e575e..474ea23c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -30,7 +30,8 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ + OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -177,6 +178,24 @@ default_cfgs = { 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), + + 'vit_base_patch32_224_clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_large_patch14_224_clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768), + 'vit_huge_patch14_224_clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), + 'vit_giant_patch14_224_clip_laion2b': _cfg( + hf_hub_id='CLIP-ViT-g-14-laion2B-s12B-b42K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), + } @@ -221,8 +240,18 @@ class LayerScale(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) @@ -244,8 +273,18 @@ class Block(nn.Module): class ResPostBlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.init_values = init_values @@ -274,8 +313,19 @@ class ResPostBlock(nn.Module): class ParallelBlock(nn.Module): def __init__( - self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + num_parallel=2, + mlp_ratio=4., + qkv_bias=False, + init_values=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() @@ -320,10 +370,31 @@ class VisionTransformer(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=Block, + ): """ Args: img_size (int, tuple): input image size @@ -362,19 +433,34 @@ class VisionTransformer(nn.Module): self.grad_checkpointing = False self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer + ) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() @@ -445,6 +531,7 @@ class VisionTransformer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) x = self._pos_embed(x) + x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -623,6 +710,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): return posemb +def _convert_openai_clip(state_dict, model): + out_dict = {} + swaps = [ + ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), + ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), + ] + for k, v in state_dict.items(): + if not k.startswith('visual.'): + continue + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + if v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + model.pos_embed, + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + out_dict[k] = v + return out_dict + + def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re @@ -631,6 +752,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): # For deit models state_dict = state_dict['model'] + if 'visual.class_embedding' in state_dict: + return _convert_openai_clip(state_dict, model) + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification @@ -833,7 +957,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs): @register_model def vit_giant_patch14_224(pretrained=False, **kwargs): - """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) @@ -842,7 +966,7 @@ def vit_giant_patch14_224(pretrained=False, **kwargs): @register_model def vit_gigantic_patch14_224(pretrained=False, **kwargs): - """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) @@ -1085,3 +1209,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-B/32 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs): + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs) + model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 24ff2096..156894ac 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -101,7 +101,16 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. """ - def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=768, + bias=True, + ): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) @@ -130,7 +139,7 @@ class HybridEmbed(nn.Module): assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) def forward(self, x): x = self.backbone(x) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 06788d2e..c82fd3d2 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,7 +1,7 @@ """ Optimizer Factory w/ Custom Weight Decay Hacked together by / Copyright 2021 Ross Wightman """ -import json +import logging from itertools import islice from typing import Optional, Callable, Tuple @@ -31,6 +31,8 @@ try: except ImportError: has_apex = False +_logger = logging.getLogger(__name__) + def param_groups_weight_decay( model: nn.Module, @@ -92,6 +94,7 @@ def param_groups_layer_decay( no_weight_decay_list: Tuple[str] = (), layer_decay: float = .75, end_layer_decay: Optional[float] = None, + verbose: bool = False, ): """ Parameter groups for layer-wise lr decay & weight decay @@ -142,8 +145,9 @@ def param_groups_layer_decay( param_group_names[group_name]["param_names"].append(name) param_groups[group_name]["params"].append(param) - # FIXME temporary output to debug new feature - print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + if verbose: + import json + _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) return list(param_groups.values())