From f77c04ff36e3a61d94e3027cb4ae22256387e70e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 15 Feb 2023 14:53:55 -0800 Subject: [PATCH] Torchscript fixes/hacks for rms_norm, refactor ParallelScalingBlock with manual combination of input projections, closer paper match --- timm/layers/fast_norm.py | 27 ++++++-- timm/layers/norm.py | 1 + timm/models/vision_transformer.py | 107 +++++++++++++++++++++++------- 3 files changed, 104 insertions(+), 31 deletions(-) diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py index 17828989..2880be99 100644 --- a/timm/layers/fast_norm.py +++ b/timm/layers/fast_norm.py @@ -90,8 +90,16 @@ def rms_norm( weight: Optional[torch.Tensor] = None, eps: float = 1e-5, ): - dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) - v = torch.var(x, dim=dims, keepdim=True) + norm_ndim = len(normalized_shape) + if torch.jit.is_scripting(): + # ndim = len(x.shape) + # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x + # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around + assert norm_ndim == 1 + v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True + else: + dims = tuple(range(-1, -norm_ndim - 1, -1)) + v = torch.var(x, dim=dims, keepdim=True) x = x * torch.rsqrt(v + eps) if weight is not None: x = x * weight @@ -104,10 +112,15 @@ def fast_rms_norm( weight: Optional[torch.Tensor] = None, eps: float = 1e-5, ) -> torch.Tensor: - if torch.jit.is_scripting() or not has_apex_rmsnorm: + if torch.jit.is_scripting(): + # this must be by itself, cannot merge with has_apex_rmsnorm return rms_norm(x, normalized_shape, weight, eps) - if weight is None: - return fused_rms_norm(x, normalized_shape, eps) - else: - return fused_rms_norm_affine(x, weight, normalized_shape, eps) + if has_apex_rmsnorm: + if weight is None: + return fused_rms_norm(x, normalized_shape, eps) + else: + return fused_rms_norm_affine(x, weight, normalized_shape, eps) + + # fallback + return rms_norm(x, normalized_shape, weight, eps) diff --git a/timm/layers/norm.py b/timm/layers/norm.py index dd939719..504060c7 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -122,6 +122,7 @@ class LayerNormExp2d(nn.LayerNorm): class RmsNorm(nn.Module): """ RmsNorm w/ fast (apex) norm if available """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index dcd0a6fa..477b4f7b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -217,6 +217,8 @@ class ParallelScalingBlock(nn.Module): Based on: 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 """ + fast_attn: Final[bool] + def __init__( self, dim, @@ -232,33 +234,76 @@ class ParallelScalingBlock(nn.Module): norm_layer=nn.LayerNorm ): super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_norm=qk_norm, - attn_drop=attn_drop, - proj_drop=drop, - norm_layer=norm_layer, - ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + mlp_hidden_dim = int(mlp_ratio * dim) + in_proj_out_dim = mlp_hidden_dim + 3 * dim + out_proj_in_dim = mlp_hidden_dim + dim + + self.in_norm = norm_layer(dim) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) + self.in_split = [mlp_hidden_dim] + [dim] * 3 + if qkv_bias: + self.register_buffer('qkv_bias', None) + self.register_parameter('mlp_bias', None) + else: + self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) - self.norm2 = norm_layer(dim) - self.mlp = Mlp( - in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, - drop=drop, - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.attn_out_proj = nn.Linear(dim, dim) + + self.mlp_drop = nn.Dropout(drop) + self.mlp_act = act_layer() + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) + + self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): - y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x)))) - y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - x = x + y1 + y2 + B, N, C = x.shape + + # Combined MLP fc1 & qkv projections + y = self.in_norm(x) + if self.mlp_bias is not None: + # Concat constant zero-bias for qkv w/ trainable mlp_bias. + # Appears faster than adding to x_mlp separately + y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias))) + else: + y = self.in_proj(y) + x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1) + + # Dot product attention w/ qk norm + q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + if self.fast_attn: + x_attn = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) + x_attn = self.attn_out_proj(x_attn) + + # MLP activation, dropout, fc2 + x_mlp = self.mlp_act(x_mlp) + x_mlp = self.mlp_drop(x_mlp) + x_mlp = self.mlp_out_proj(x_mlp) + + # Add residual w/ drop path & layer scale applied + y = self.drop_path(self.ls(x_attn + x_mlp)) + x = x + y return x @@ -1249,6 +1294,7 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'vit_base_patch16_xp_224.untrained': _cfg(url=''), 'vit_large_patch14_xp_224.untrained': _cfg(url=''), 'vit_huge_patch14_xp_224.untrained': _cfg(url=''), }) @@ -1750,6 +1796,19 @@ def flexivit_large(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch16_xp_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + @register_model def vit_large_patch14_xp_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.