From b9cfb64412e367a1352d46f00906453d0274282c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 14 Jun 2021 12:31:44 -0700 Subject: [PATCH] Support npz custom load for vision transformer hybrid models. Add posembed rescale for npz load. --- timm/models/layers/pool2d_same.py | 10 +- timm/models/vision_transformer.py | 96 ++++++++++++----- timm/models/vision_transformer_hybrid.py | 131 ++++++++++++++++++----- 3 files changed, 181 insertions(+), 56 deletions(-) diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 5fcd0f1f..4c2a1c44 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -27,7 +27,8 @@ class AvgPool2dSame(nn.AvgPool2d): super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) def forward(self, x): - return avg_pool2d_same( + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) @@ -41,14 +42,15 @@ def max_pool2d_same( class MaxPool2dSame(nn.MaxPool2d): """ Tensorflow like 'SAME' wrapper for 2D max pooling """ - def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) dilation = to_2tuple(dilation) - super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) def forward(self, x): - return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index c44358df..7dd9137e 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -52,6 +52,10 @@ default_cfgs = { url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), + 'vit_tiny_patch16_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), 'vit_small_patch16_224': _cfg( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), @@ -60,6 +64,14 @@ default_cfgs = { url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), + 'vit_small_patch16_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), + 'vit_small_patch32_384': _cfg( + url='', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 + ), # patch models (weights ported from official Google JAX impl) 'vit_base_patch16_224': _cfg( @@ -102,6 +114,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub='timm/vit_huge_patch14_224_in21k', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), @@ -371,24 +384,53 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = import numpy as np def _n2p(w, t=True): - if t and 
w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif t and w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif t and w.ndim == 2: - w = w.transpose([1, 0]) + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) - if not prefix: - prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix - - input_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(input_conv_w) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: @@ -396,23 +438,18 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T, - _n2p(w[f'{mha_prefix}key/kernel'], 
t=False).flatten(1).T, - _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T])) + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1), - _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1), - _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)])) + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel'])) - block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias'])) - block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel'])) - block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) @@ -478,6 +515,7 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw default_cfg=default_cfg, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], **kwargs) return model @@ -510,6 +548,16 @@ def vit_small_patch32_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
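Note: the `_n2p()` helper above converts arrays from the JAX/Flax `.npz` layout to the PyTorch layout (conv kernels HWIO to OIHW, linear kernels (in, out) to (out, in)), and the pos-embed branch resizes the position grid by interpolation when the checkpoint was trained at a different resolution than the target model. Below is a minimal, self-contained sketch of both ideas; `npz_conv_kernel` and `resize_grid_pos_embed` are hypothetical names used only for illustration, and this is not the exact `resize_pos_embed()` implementation referenced in the diff.

import numpy as np
import torch
import torch.nn.functional as F

# JAX/Flax stores conv kernels as HWIO; PyTorch expects OIHW (cf. _n2p() above).
npz_conv_kernel = np.zeros((16, 16, 3, 768), dtype=np.float32)                # hypothetical npz array
torch_conv_weight = torch.from_numpy(npz_conv_kernel.transpose(3, 2, 0, 1))   # -> (768, 3, 16, 16)


def resize_grid_pos_embed(posemb: torch.Tensor, new_grid: int, num_tokens: int = 1) -> torch.Tensor:
    """Resize the grid part of a (1, num_tokens + H*W, C) position embedding (illustrative sketch)."""
    tok, grid = posemb[:, :num_tokens], posemb[:, num_tokens:]
    gs_old = int(grid.shape[1] ** 0.5)                                  # assumes a square grid
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)      # (1, C, gs_old, gs_old)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
    return torch.cat([tok, grid], dim=1)


# e.g. a 224px/16 checkpoint (14x14 grid) loaded into a 384px/16 model (24x24 grid)
posemb_224 = torch.randn(1, 1 + 14 * 14, 768)
posemb_384 = resize_grid_pos_embed(posemb_224, new_grid=24)                   # -> (1, 1 + 24 * 24, 768)

This is also why `_create_vision_transformer()` below sets `pretrained_custom_load='npz' in default_cfg['url']`: configs whose pretrained URL points at a Google `.npz` checkpoint (e.g. `vit_huge_patch14_224_in21k`) are loaded via the custom npz loader (`_load_weights()`) rather than the usual PyTorch `state_dict` path.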
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index c807ee9a..1bfe6685 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -35,26 +35,34 @@ def _cfg(url='', **kwargs): default_cfgs = { - # hybrid in-21k models (weights ported from official Google JAX impl where they exist) - 'vit_base_r50_s16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', - num_classes=21843, crop_pct=0.9), - - # hybrid in-1k models (weights ported from official JAX impl) - 'vit_base_r50_s16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), crop_pct=1.0), - - # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) + # hybrid in-1k models (weights ported from official JAX impl where they exist) 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + 'vit_tiny_r_s16_p8_384': _cfg( + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_224': _cfg(), 'vit_small_r26_s32_224': _cfg(), + 'vit_small_r26_s32_384': _cfg( + input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_r20_s16_224': _cfg(), 'vit_base_r26_s32_224': _cfg(), 'vit_base_r50_s16_224': _cfg(), + 'vit_base_r50_s16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_r50_s32_224': _cfg(), + 'vit_large_r50_s32_384': _cfg(), + + # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_small_r26_s32_224_in21k': _cfg( + num_classes=21843, crop_pct=0.9), + 'vit_small_r26_s32_384_in21k': _cfg( + num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_r50_s16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg( @@ -99,7 +107,8 @@ class HybridEmbed(nn.Module): else: feature_dim = self.backbone.num_features assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): @@ -133,37 +142,35 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): @register_model -def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. 
""" - backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_r50_s16_384(pretrained=False, **kwargs): - """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. """ - backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. """ backbone = _resnetv2(layers=(), **kwargs) model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. @@ -245,17 +263,74 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. 
""" backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer_hybrid( 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.