From 41dc49a33752b72dbb3cff5cb181b9953e07971f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 2 May 2022 15:37:39 -0700 Subject: [PATCH] Vision Transformer refactoring and Rel Pos impl --- README.md | 9 + timm/models/__init__.py | 1 + timm/models/vision_transformer.py | 190 +++++----- timm/models/vision_transformer_hybrid.py | 2 +- timm/models/vision_transformer_relpos.py | 425 +++++++++++++++++++++++ 5 files changed, 544 insertions(+), 83 deletions(-) create mode 100644 timm/models/vision_transformer_relpos.py diff --git a/README.md b/README.md index 355cedaf..df5fb968 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New + +### May 2, 2022 +* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) + * `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool + * `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool +* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`) +* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae). + ### April 22, 2022 * `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). * Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. 
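As a quick sketch of what the notes above enable (assuming a `timm` checkout that includes this patch; model names and input sizes are taken from the configs in this diff), the new `vit_relpos_*` entrypoints build like any other `timm` model, and the refactored `vit_*` models accept `class_token` / `global_pool` / `fc_norm` overrides through `create_model`:

```python
import torch
import timm  # assumes a checkout that includes this patch

# New relative-position ViT: rel pos bias, layer scale, no class token, avg pool.
model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False).eval()

# Refactored standard ViT with the new options: class token removed,
# global average pooling plus a pre-classifier norm (fc_norm), ala BEiT/MAE.
vit_avg = timm.create_model(
    'vit_base_patch16_224', pretrained=False,
    class_token=False, global_pool='avg', fc_norm=True)

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```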
diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 45ead5dc..c1d63dcc 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -49,6 +49,7 @@ from .vgg import * from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * +from .vision_transformer_relpos import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 17faba53..33cc5db2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -23,6 +23,7 @@ import math import logging from functools import partial from collections import OrderedDict +from typing import Optional import torch import torch.nn as nn @@ -107,7 +108,6 @@ default_cfgs = { 'vit_giant_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''), - 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( @@ -171,7 +171,12 @@ default_cfgs = { mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', ), - # experimental + 'vit_base_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), @@ -229,8 +234,7 @@ class Block(nn.Module): self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -240,6 +244,36 @@ class Block(nn.Module): return x +class ResPostBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + class ParallelBlock(nn.Module): def __init__( @@ -290,9 +324,9 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True, + fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -305,33 +339,36 @@ class VisionTransformer(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + init_values: (float): layer-scale init values drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate - weight_init: (str): weight init scheme - init_values: (float): layer-scale init values + weight_init (str): weight init scheme + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 + self.num_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -340,38 +377,21 @@ class VisionTransformer(nn.Module): dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 
act_layer=act_layer) for i in range(depth)]) - use_fc_norm = self.global_pool == 'avg' self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() - # Representation layer. Used for original ViT models w/ in21k pretraining. - self.representation_size = representation_size - self.pre_logits = nn.Identity() - if representation_size: - self._reset_representation(representation_size) - # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) - def _reset_representation(self, representation_size): - self.representation_size = representation_size - if self.representation_size: - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(self.embed_dim, self.representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) - nn.init.normal_(self.cls_token, std=1e-6) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): @@ -401,19 +421,17 @@ class VisionTransformer(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): + def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool - if representation_size is not None: - self._reset_representation(representation_size) - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -424,9 +442,8 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) - x = self.pre_logits(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -441,6 +458,8 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): @@ -449,9 +468,6 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) - elif 
name.startswith('pre_logits'): - lecun_normal_(module.weight) - nn.init.zeros_(module.bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: @@ -460,6 +476,8 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): @@ -473,6 +491,8 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): @@ -543,9 +563,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' @@ -601,6 +622,9 @@ def checkpoint_filter_fn(state_dict, model): # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue out_dict[k] = v return out_dict @@ -609,21 +633,10 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - # NOTE this extra code to support handling of repr size for in21k pretrained models pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) - default_num_classes = pretrained_cfg['num_classes'] - num_classes = kwargs.get('num_classes', default_num_classes) - repr_size = kwargs.pop('representation_size', None) - if repr_size is not None and num_classes != default_num_classes: - # Remove representation layer if fine-tuning. This may not always be the desired action, - # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 
- _logger.warning("Removing representation layer for fine-tuning.") - repr_size = None - model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, - representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, pretrained_custom_load='npz' in pretrained_cfg['url'], **kwargs) @@ -696,16 +709,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model -@register_model -def vit_base2_patch32_256(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32) - # FIXME experiment - """ - model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) - model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). @@ -860,8 +863,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -872,8 +874,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -884,8 +885,7 @@ def vit_base_patch8_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -896,8 +896,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -908,8 +907,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -920,8 +918,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -930,7 +927,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): def vit_base_patch16_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -940,7 +936,6 @@ def vit_base_patch16_224_sam(pretrained=False, **kwargs): def vit_base_patch32_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -1002,6 +997,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): return model +# Experimental models below + +@register_model +def vit_base_patch32_plus_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16+) + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, + block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) + model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_patch16_36x1_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. 
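With the representation (pre-logits) layer gone, the classifier path in `vision_transformer.py` is now just optional pooling -> `fc_norm` -> `head`. A hedged sketch of how that plays out for the residual post-norm variant registered above (shapes assume the default 224x224 input at patch size 16; in `ResPostBlock` the usual pre-norm residual `x + drop_path(ls(attn(norm(x))))` becomes `x + drop_path(norm(attn(x)))`, with the norm weights initialized from `init_values`):

```python
import torch
from timm.models.vision_transformer import vit_base_patch16_rpn_224

model = vit_base_patch16_rpn_224(pretrained=False).eval()

x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                    # (2, 196, 768) -- no class token in this variant
pooled = model.forward_head(feats, pre_logits=True)  # (2, 768): avg pool + fc_norm, no pre_logits layer
logits = model.forward_head(feats)                   # (2, 1000)

model.reset_classifier(10)   # representation_size argument no longer exists
print(model.head)            # Linear(in_features=768, out_features=10, bias=True)
```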
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 0eee2044..24ff2096 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -295,7 +295,7 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer_hybrid( 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py new file mode 100644 index 00000000..056dba97 --- /dev/null +++ b/timm/models/vision_transformer_relpos.py @@ -0,0 +1,425 @@ +""" Relative Position Vision Transformer (ViT) in PyTorch + +Hacked together by / Copyright 2022, Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'vit_relpos_base_patch32_plus_rpn_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', + input_size=(3, 256, 256)), + 'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)), + 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), + 'vit_relpos_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), +} + + +def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: + # cut and paste w/ modifications from swin / beit codebase + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = win_size[0] * win_size[1] + coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + if class_token: + num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, 
dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + else: + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor: + """Method initializes the pair-wise relative positions to compute the positional biases.""" + coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) + relative_coords = coordinates[:, :, None] - coordinates[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).float() + relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs()) + return relative_coordinates_log + + +class RelPosMlp(nn.Module): + # based on timm swin-v2 impl + def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False): + super().__init__() + self.window_size = window_size + self.window_area = self.window_size[0] * self.window_size[1] + self.class_token = 1 if class_token else 0 + self.num_heads = num_heads + + self.mlp = Mlp( + 2, # x, y + hidden_features=min(128, hidden_dim * num_heads), + out_features=num_heads, + act_layer=nn.ReLU, + drop=(0.125, 0.) + ) + + self.register_buffer( + 'rel_coords_log', + gen_relative_position_log(window_size), + persistent=False + ) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0) + if self.class_token: + relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + return relative_position_bias + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosBias(nn.Module): + + def __init__(self, window_size, num_heads, class_token=False): + super().__init__() + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.class_token = 1 if class_token else 0 + self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.register_buffer( + "relative_position_index", + gen_relative_position_index(self.window_size, class_token=self.class_token), + persistent=False, + ) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.bias_shape) # win_h * win_w, win_h * win_w, num_heads + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + return relative_position_bias + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.rel_pos = 
rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class RelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostRelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class VisionTransformerRelPos(nn.Module): + """ Vision Transformer w/ Relative Position Bias + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False, + rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False, + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'avg') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + class_token (bool): use class token (default: False) + rel_pos_type (str): type of relative position + shared_rel_pos (bool): share relative pos across all blocks + fc_norm (bool): use pre classifier norm + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 if class_token else 0 + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + feat_size = self.patch_embed.grid_size + + rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias + rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token) + self.shared_rel_pos = None + if shared_rel_pos: + self.shared_rel_pos = rel_pos_cls(num_heads=num_heads) + # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
+ rel_pos_cls = None + + self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, + init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'moco', '') + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + # FIXME weight init scheme using PyTorch defaults currently + #named_apply(get_init_weights_vit(mode, head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos) + else: + x = blk(x, shared_rel_pos=shared_rel_pos) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs) + return model + + +@register_model +def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, + block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs): + """
ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model
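To make the relative-position pieces above a bit more concrete, here is a small sketch (assuming the new module is importable, per the `__init__.py` change) applying the two bias flavors to a dummy attention map for a 14x14 feature grid:

```python
import torch
from timm.models.vision_transformer_relpos import (
    RelPosBias, RelPosMlp, gen_relative_position_log)

win = (14, 14)                                # feature grid of a 224 / patch-16 model
attn = torch.zeros(1, 12, 14 * 14, 14 * 14)   # (batch, heads, tokens, tokens)

# Table-based bias (Swin-V1 / BEiT style): one learned scalar per relative offset per head.
bias_table = RelPosBias(window_size=win, num_heads=12)
print(bias_table(attn).shape)                 # torch.Size([1, 12, 196, 196])

# MLP bias over log-spaced coords (Swin-V2 style): continuous, friendlier to resolution changes.
bias_mlp = RelPosMlp(window_size=win, num_heads=12)
print(bias_mlp(attn).shape)                   # torch.Size([1, 12, 196, 196])

# The log-coordinate transform itself: signed log1p of the pairwise (dy, dx) offsets.
print(gen_relative_position_log(win).shape)   # torch.Size([196, 196, 2])
```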