Vision Transformer refactoring and Rel Pos impl

pull/1245/head
Ross Wightman 2 years ago
parent b7cb8d0337
commit 41dc49a337

@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### May 2, 2022
* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`)
* `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool
* `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool
* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`)
* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae).
### April 22, 2022
* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/).
* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress.

@ -49,6 +49,7 @@ from .vgg import *
from .visformer import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .volo import *
from .vovnet import *
from .xception import *

@ -23,6 +23,7 @@ import math
import logging
from functools import partial
from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn
@ -107,7 +108,6 @@ default_cfgs = {
'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''),
'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
@ -171,7 +171,12 @@ default_cfgs = {
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
),
# experimental
'vit_base_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),
@ -229,8 +234,7 @@ class Block(nn.Module):
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -240,6 +244,36 @@ class Block(nn.Module):
return x
class ResPostBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.init_values = init_values
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.init_weights()
def init_weights(self):
# NOTE this init overrides that base model init with specific changes for the block type
if self.init_values is not None:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def forward(self, x):
x = x + self.drop_path1(self.norm1(self.attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x
class ParallelBlock(nn.Module):
def __init__(
@ -290,9 +324,9 @@ class VisionTransformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True,
fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
"""
Args:
img_size (int, tuple): input image size
@ -305,33 +339,36 @@ class VisionTransformer(nn.Module):
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
init_values: (float): layer-scale init values
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init: (str): weight init scheme
init_values: (float): layer-scale init values
weight_init (str): weight init scheme
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.num_tokens = 1 if class_token else 0
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -340,38 +377,21 @@ class VisionTransformer(nn.Module):
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Representation layer. Used for original ViT models w/ in21k pretraining.
self.representation_size = representation_size
self.pre_logits = nn.Identity()
if representation_size:
self._reset_representation(representation_size)
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
final_chs = self.representation_size if self.representation_size else self.embed_dim
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
def _reset_representation(self, representation_size):
self.representation_size = representation_size
if self.representation_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(self.embed_dim, self.representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m):
@ -401,19 +421,17 @@ class VisionTransformer(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
if representation_size is not None:
self._reset_representation(representation_size)
final_chs = self.representation_size if self.representation_size else self.embed_dim
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
@ -424,9 +442,8 @@ class VisionTransformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.pre_logits(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -441,6 +458,8 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
@ -449,9 +468,6 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
@ -460,6 +476,8 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def init_weights_vit_moco(module: nn.Module, name: str = ''):
@ -473,6 +491,8 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def get_init_weights_vit(mode='jax', head_bias: float = 0.):
@ -543,9 +563,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
# NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
@ -601,6 +622,9 @@ def checkpoint_filter_fn(state_dict, model):
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
elif 'pre_logits' in k:
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
continue
out_dict[k] = v
return out_dict
@ -609,21 +633,10 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
# NOTE this extra code to support handling of repr size for in21k pretrained models
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
default_num_classes = pretrained_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model = build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in pretrained_cfg['url'],
**kwargs)
@ -696,16 +709,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base2_patch32_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32)
# FIXME experiment
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -860,8 +863,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -872,8 +874,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -884,8 +885,7 @@ def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -896,8 +896,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -908,8 +907,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -920,8 +918,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -930,7 +927,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
def vit_base_patch16_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
# NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
return model
@ -940,7 +936,6 @@ def vit_base_patch16_224_sam(pretrained=False, **kwargs):
def vit_base_patch32_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
# NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
return model
@ -1002,6 +997,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
return model
# Experimental models below
@register_model
def vit_base_patch32_plus_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+)
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+)
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ residual post-norm
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.

@ -295,7 +295,7 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model

@ -0,0 +1,425 @@
""" Relative Position Vision Transformer (ViT) in PyTorch
Hacked together by / Copyright 2022, Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'vit_relpos_base_patch32_plus_rpn_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
}
def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
# cut and paste w/ modifications from swin / beit codebase
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
window_area = win_size[0] * win_size[1]
coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += win_size[1] - 1
relative_coords[:, :, 0] *= 2 * win_size[1] - 1
if class_token:
num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
else:
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
return relative_position_index
def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor:
"""Method initializes the pair-wise relative positions to compute the positional biases."""
coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)
relative_coords = coordinates[:, :, None] - coordinates[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).float()
relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs())
return relative_coordinates_log
class RelPosMlp(nn.Module):
# based on timm swin-v2 impl
def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.class_token = 1 if class_token else 0
self.num_heads = num_heads
self.mlp = Mlp(
2, # x, y
hidden_features=min(128, hidden_dim * num_heads),
out_features=num_heads,
act_layer=nn.ReLU,
drop=(0.125, 0.)
)
self.register_buffer(
'rel_coords_log',
gen_relative_position_log(window_size),
persistent=False
)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0)
if self.class_token:
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
return relative_position_bias
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, class_token=False):
super().__init__()
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.class_token = 1 if class_token else 0
self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=self.class_token),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.bias_shape) # win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
return relative_position_bias
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class RelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class ResPostRelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.init_values = init_values
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.init_weights()
def init_weights(self):
# NOTE this init overrides that base model init with specific changes for the block type
if self.init_values is not None:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x
class VisionTransformerRelPos(nn.Module):
""" Vision Transformer w/ Relative Position Bias
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
class_token (bool): use class token (default: False)
rel_pos_ty pe (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
fc_norm (bool): use pre classifier norm
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
assert class_token or global_pool != 'token'
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
feat_size = self.patch_embed.grid_size
rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias
rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token)
self.shared_rel_pos = None
if shared_rel_pos:
self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
# NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
rel_pos_cls = None
self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'moco', '')
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
# FIXME weight init scheme using PyTorch defaults curently
#named_apply(get_init_weights_vit(mode, head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos)
else:
x = blk(x, shared_rel_pos=shared_rel_pos)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)
return model
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
fc_norm=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
Loading…
Cancel
Save