From 621e1b2182508afe97bfbabb0bf6cc83ba02d69f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Feb 2023 23:32:04 -0800 Subject: [PATCH] Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet. --- timm/layers/__init__.py | 2 +- timm/layers/fast_norm.py | 35 ++++ timm/layers/norm.py | 39 ++++- timm/layers/pos_embed_rel.py | 23 +-- timm/models/maxxvit.py | 67 ++++++-- timm/models/vision_transformer.py | 194 ++++++++++++++++++++--- timm/models/vision_transformer_relpos.py | 151 ++++++++++++++---- 7 files changed, 420 insertions(+), 91 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 8e555b8b..47b02892 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -28,7 +28,7 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py index fb35e47d..17828989 100644 --- a/timm/layers/fast_norm.py +++ b/timm/layers/fast_norm.py @@ -17,6 +17,12 @@ try: except ImportError: has_apex = False +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm + has_apex_rmsnorm = True +except ImportError: + has_apex_rmsnorm = False + # fast (ie lower precision LN) can be disabled with this flag if issues crop up _USE_FAST_NORM = False # defaulting to False for now @@ -76,3 +82,32 @@ def fast_layer_norm( with torch.cuda.amp.autocast(enabled=False): return F.layer_norm(x, normalized_shape, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5, +): + dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) + v = torch.var(x, dim=dims, keepdim=True) + x = x * torch.rsqrt(v + eps) + if weight is not None: + x = x * weight + return x + + +def fast_rms_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> torch.Tensor: + if torch.jit.is_scripting() or not has_apex_rmsnorm: + return rms_norm(x, normalized_shape, weight, eps) + + if weight is None: + return fused_rms_norm(x, normalized_shape, eps) + else: + return fused_rms_norm_affine(x, weight, normalized_shape, eps) diff --git a/timm/layers/norm.py b/timm/layers/norm.py index 77d719ed..dd939719 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -4,12 +4,14 @@ Norm layer definitions that support fast norm and consistent channel arg order ( Hacked together by / Copyright 2022 Ross Wightman """ +import numbers +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm class GroupNorm(nn.GroupNorm): @@ -115,3 +117,38 @@ class LayerNormExp2d(nn.LayerNorm): else: x = _layer_norm_cf(x, self.weight, self.bias, self.eps) return x + + 
+class RmsNorm(nn.Module): + """ RmsNorm w/ fast (apex) norm if available + """ + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE fast norm fallback needs our rms norm impl, so both paths through here. + # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. + x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + return x diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 2ef25670..7b843dc5 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -83,8 +83,8 @@ def gen_relative_log_coords( pretrained_win_size: Tuple[int, int] = (0, 0), mode='swin', ): - assert mode in ('swin', 'cr', 'rw') - # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well + assert mode in ('swin', 'cr') + # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) @@ -100,18 +100,9 @@ def gen_relative_log_coords( relative_coords_table = torch.sign(relative_coords_table) * torch.log2( 1.0 + relative_coords_table.abs()) / math.log2(8) else: - if mode == 'rw': - # cr w/ window size normalization -> [-1,1] log coords - relative_coords_table[:, :, 0] /= (win_size[0] - 1) - relative_coords_table[:, :, 1] /= (win_size[1] - 1) - relative_coords_table *= 8 # scale to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) - relative_coords_table /= math.log2(9) # -> [-1, 1] - else: - # mode == 'cr' - relative_coords_table = torch.sign(relative_coords_table) * torch.log( - 1.0 + relative_coords_table.abs()) + # mode == 'cr' + relative_coords_table = torch.sign(relative_coords_table) * torch.log( + 1.0 + relative_coords_table.abs()) return relative_coords_table @@ -141,10 +132,6 @@ class RelPosMlp(nn.Module): self.bias_act = nn.Sigmoid() self.bias_gain = 16 mlp_bias = (True, False) - elif mode == 'rw': - self.bias_act = nn.Tanh() - self.bias_gain = 4 - mlp_bias = True else: self.bias_act = nn.Identity() self.bias_gain = None diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 9030f206..5a164e88 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -160,6 +160,7 @@ class Attention2d(nn.Module): self.dim_head = dim_head self.head_first = head_first self.scale = dim_head ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME self.qkv = 
nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None @@ -175,15 +176,31 @@ class Attention2d(nn.Module): else: q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) - attn = (q.transpose(-2, -1) @ k) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.fast_attn: + if self.rel_pos is not None: + attn_bias = self.rel_pos.get_bias() + elif shared_rel_pos is not None: + attn_bias = shared_rel_pos + else: + attn_bias = None + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(-1, -2), + k.transpose(-1, -2), + v.transpose(-1, -2), + attn_mask=attn_bias, + dropout_p=self.attn_drop.p, + ).transpose(-1, -2).reshape(B, -1, H, W) + else: + q = q * self.scale + attn = q.transpose(-2, -1) @ k + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) - x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) x = self.proj(x) x = self.proj_drop(x) return x @@ -211,6 +228,7 @@ class AttentionCl(nn.Module): self.dim_head = dim_head self.head_first = head_first self.scale = dim_head ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None @@ -227,15 +245,30 @@ class AttentionCl(nn.Module): else: q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2) - attn = (q @ k.transpose(-2, -1)) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) + if self.fast_attn: + if self.rel_pos is not None: + attn_bias = self.rel_pos.get_bias() + elif shared_rel_pos is not None: + attn_bias = shared_rel_pos + else: + attn_bias = None + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(restore_shape + (-1,)) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index d32f9dea..95d87126 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -37,7 +37,7 @@ import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ - resample_abs_pos_embed + resample_abs_pos_embed, RmsNorm from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._pretrained import generate_default_cfgs 
@@ -51,28 +51,49 @@ _logger = logging.getLogger(__name__) class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0., + proj_drop=0., + norm_layer=nn.LayerNorm, + ): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fast_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -96,6 +117,7 @@ class Block(nn.Module): num_heads, mlp_ratio=4., qkv_bias=False, + qk_norm=False, drop=0., attn_drop=0., init_values=None, @@ -105,13 +127,25 @@ class Block(nn.Module): ): super().__init__() self.norm1 = norm_layer(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -129,6 +163,7 @@ class ResPostBlock(nn.Module): num_heads, mlp_ratio=4., qkv_bias=False, + qk_norm=False, drop=0., attn_drop=0., init_values=None, @@ -139,11 +174,24 @@ class ResPostBlock(nn.Module): super().__init__() self.init_values = init_values - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -161,8 +209,61 @@ class ResPostBlock(nn.Module): return x -class ParallelBlock(nn.Module): +class ParallelScalingBlock(nn.Module): + """ Parallel ViT block (MLP & Attention in parallel) + Based on: + 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 + """ + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + def forward(self, x): + y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + x = x + y1 + y2 + return x + + +class ParallelThingsBlock(nn.Module): + """ Parallel ViT block (N parallel attention followed by N parallel MLP) + Based on: + `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ def __init__( self, dim, @@ -170,6 +271,7 @@ class ParallelBlock(nn.Module): num_parallel=2, mlp_ratio=4., qkv_bias=False, + qk_norm=False, init_values=None, drop=0., attn_drop=0., @@ -184,13 +286,26 @@ class ParallelBlock(nn.Module): for _ in range(num_parallel): self.attns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), - ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), + ('attn', Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + )), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) ]))) self.ffns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), - ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), + ('mlp', Mlp( + dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + )), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) @@ -232,6 +347,7 @@ class VisionTransformer(nn.Module): num_heads=12, mlp_ratio=4., qkv_bias=True, + qk_norm=False, init_values=None, class_token=True, no_embed_class=False, @@ -305,6 +421,7 @@ class VisionTransformer(nn.Module): num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_norm=qk_norm, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, @@ -641,9 +758,8 @@ def checkpoint_filter_fn( """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} - if 'model' in state_dict: - # For deit models - state_dict = state_dict['model'] + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) if 'visual.class_embedding' in state_dict: return _convert_openai_clip(state_dict, model) @@ -1129,6 +1245,9 @@ default_cfgs = generate_default_cfgs({ url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'vit_large_patch14_xp_224.untrained': _cfg(url=''), + 'vit_huge_patch14_xp_224.untrained': _cfg(url=''), }) @@ -1566,7 +1685,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs): Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict( - patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock) + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock) model = _create_vision_transformer( 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1577,7 +1696,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock) + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock) model = _create_vision_transformer( 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1625,3 +1745,29 @@ def flexivit_large(pretrained=False, **kwargs): model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model + + +@register_model +def vit_large_patch14_xp_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. 
+ """ + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_xp_224(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index a7cf3e53..f9fede53 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -25,14 +25,27 @@ _logger = logging.getLogger(__name__) class RelPosAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + rel_pos_cls=None, + attn_drop=0., + proj_drop=0., + norm_layer=nn.LayerNorm, + ): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) @@ -40,18 +53,35 @@ class RelPosAttention(nn.Module): def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q = self.q_norm(q) + k = self.k_norm(k) + + if self.fast_attn: + if self.rel_pos is not None: + attn_bias = self.rel_pos.get_bias() + elif shared_rel_pos is not None: + attn_bias = shared_rel_pos + else: + attn_bias = None + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) 
x = self.proj_drop(x) return x @@ -70,18 +100,42 @@ class LayerScale(nn.Module): class RelPosBlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + rel_pos_cls=None, + init_values=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = RelPosAttention( - dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + rel_pos_cls=rel_pos_cls, + attn_drop=attn_drop, + proj_drop=drop, + ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -94,17 +148,41 @@ class RelPosBlock(nn.Module): class ResPostRelPosBlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + rel_pos_cls=None, + init_values=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.init_values = init_values self.attn = RelPosAttention( - dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + rel_pos_cls=rel_pos_cls, + attn_drop=attn_drop, + proj_drop=drop, + ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -144,6 +222,7 @@ class VisionTransformerRelPos(nn.Module): num_heads=12, mlp_ratio=4., qkv_bias=True, + qk_norm=False, init_values=1e-6, class_token=False, fc_norm=False, @@ -171,6 +250,7 @@ class VisionTransformerRelPos(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True + qk_norm (bool): Enable normalization of query and key in attention init_values: (float): layer-scale init values class_token (bool): use class token (default: False) fc_norm (bool): use pre classifier norm instead of pre-pool @@ -197,18 +277,19 @@ class VisionTransformerRelPos(nn.Module): self.grad_checkpointing = False self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) feat_size = self.patch_embed.grid_size rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens) if rel_pos_type.startswith('mlp'): if rel_pos_dim: rel_pos_args['hidden_dim'] = rel_pos_dim - # FIXME experimenting with different relpos log coord configs if 'swin' in rel_pos_type: rel_pos_args['mode'] = 'swin' - elif 'rw' in rel_pos_type: - rel_pos_args['mode'] = 'rw' rel_pos_cls = partial(RelPosMlp, **rel_pos_args) else: rel_pos_cls = partial(RelPosBias, **rel_pos_args) @@ -223,9 +304,19 @@ class VisionTransformerRelPos(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ block_fn( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, - init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], - norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + rel_pos_cls=rel_pos_cls, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()
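
Note (not part of the patch): for reviewers who want to try the core pattern in isolation, the following is a minimal standalone sketch of the fused-attention fallback plus qk-norm idea this diff applies across vit, vit_relpos and maxxvit. The class name SketchAttention and the toy shapes are hypothetical; the real implementations are the ones in the files above.

    # Minimal illustrative sketch, assuming PyTorch >= 1.13 (fused path only on >= 2.0).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SketchAttention(nn.Module):
        def __init__(self, dim, num_heads=8, qkv_bias=False, qk_norm=False,
                     attn_drop=0., norm_layer=nn.LayerNorm):
            super().__init__()
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            # Use the fused kernel when the installed PyTorch provides it.
            self.fast_attn = hasattr(F, 'scaled_dot_product_attention')
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            # Per-head q/k normalization as in 'Scaling ViT to 22-B Params'.
            self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
            q, k = self.q_norm(q), self.k_norm(k)
            if self.fast_attn:
                # Fused path: scaling, softmax and dropout happen inside the kernel.
                x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
            else:
                # Reference path, matching the pre-2.0 implementation.
                attn = (q * self.scale) @ k.transpose(-2, -1)
                attn = self.attn_drop(attn.softmax(dim=-1))
                x = attn @ v
            x = x.transpose(1, 2).reshape(B, N, C)
            return self.proj(x)


    if __name__ == '__main__':
        attn = SketchAttention(dim=64, num_heads=4, qk_norm=True)
        out = attn(torch.randn(2, 16, 64))
        print(out.shape)  # torch.Size([2, 16, 64])

The sketch mirrors the patch's behaviour, including passing attn_drop.p straight to the fused call; relative position biases (maxxvit / vit_relpos) would be supplied via the attn_mask argument as shown in the hunks above.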