Compare commits

...

5 Commits

@@ -28,7 +28,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same

@@ -17,6 +17,12 @@ try:
 except ImportError:
     has_apex = False
 
+try:
+    from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
+    has_apex_rmsnorm = True
+except ImportError:
+    has_apex_rmsnorm = False
+
 
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
@@ -76,3 +82,45 @@ def fast_layer_norm(
     with torch.cuda.amp.autocast(enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    norm_ndim = len(normalized_shape)
+    if torch.jit.is_scripting():
+        # ndim = len(x.shape)
+        # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
+        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
+        assert norm_ndim == 1
+        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+    else:
+        dims = tuple(range(-1, -norm_ndim - 1, -1))
+        v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return rms_norm(x, normalized_shape, weight, eps)
+
+    if has_apex_rmsnorm:
+        if weight is None:
+            return fused_rms_norm(x, normalized_shape, eps)
+        else:
+            return fused_rms_norm_affine(x, weight, normalized_shape, eps)
+
+    # fallback
+    return rms_norm(x, normalized_shape, weight, eps)
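
A minimal usage sketch of the pure-PyTorch fallback added above (no APEX required); the `timm.layers.fast_norm` import path is assumed from the package layout implied by these hunks:

```python
import torch

# assumed import path; rms_norm / fast_rms_norm are the helpers added in the hunk above
from timm.layers.fast_norm import rms_norm

x = torch.randn(2, 196, 768)   # (batch, tokens, channels)
weight = torch.ones(768)       # elementwise affine scale
y = rms_norm(x, normalized_shape=[768], weight=weight, eps=1e-6)
print(y.shape)                 # torch.Size([2, 196, 768])
```

With apex installed, `fast_rms_norm` dispatches to the fused `fused_rms_norm` / `fused_rms_norm_affine` kernels instead, and under torchscript it always takes the pure-PyTorch path above.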

@@ -4,12 +4,14 @@ Norm layer definitions that support fast norm and consistent channel arg order (
 Hacked together by / Copyright 2022 Ross Wightman
 """
+import numbers
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
 
 
 class GroupNorm(nn.GroupNorm):
@@ -115,3 +117,39 @@ class LayerNormExp2d(nn.LayerNorm):
         else:
             x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
         return x
+
+
+class RmsNorm(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return x
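
A short sketch of the new module used as a drop-in token-sequence norm layer; it is re-exported from `timm.layers` by the `__init__` change above:

```python
import torch

from timm.layers import RmsNorm    # re-exported via the norm import change above

norm = RmsNorm(768)                # channels=768, affine weight initialized to ones
tokens = torch.randn(2, 196, 768)  # (batch, tokens, channels)
out = norm(tokens)                 # RMS-normalized over the channel dim, same shape as input
assert out.shape == tokens.shape
```

Because forward always routes through `fast_rms_norm`, the APEX fused kernel is used when available and the pure-PyTorch `rms_norm` fallback otherwise.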

@@ -83,8 +83,8 @@ def gen_relative_log_coords(
         pretrained_win_size: Tuple[int, int] = (0, 0),
         mode='swin',
 ):
-    assert mode in ('swin', 'cr', 'rw')
-    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
+    assert mode in ('swin', 'cr')
+    # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@@ -99,15 +99,6 @@ def gen_relative_log_coords(
         relative_coords_table *= 8  # normalize to -8, 8
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             1.0 + relative_coords_table.abs()) / math.log2(8)
     else:
-        if mode == 'rw':
-            # cr w/ window size normalization -> [-1,1] log coords
-            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
-            relative_coords_table *= 8  # scale to -8, 8
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-                1.0 + relative_coords_table.abs())
-            relative_coords_table /= math.log2(9)   # -> [-1, 1]
-        else:
-            # mode == 'cr'
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
-                1.0 + relative_coords_table.abs())
+        # mode == 'cr'
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log(
+            1.0 + relative_coords_table.abs())
@@ -141,10 +132,6 @@ class RelPosMlp(nn.Module):
             self.bias_act = nn.Sigmoid()
             self.bias_gain = 16
             mlp_bias = (True, False)
-        elif mode == 'rw':
-            self.bias_act = nn.Tanh()
-            self.bias_gain = 4
-            mlp_bias = True
         else:
             self.bias_act = nn.Identity()
             self.bias_gain = None
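
For reference, a small standalone sketch contrasting the two log-coordinate modes that remain after the experimental 'rw' variant is dropped (raw integer offsets stand in for the scaled coordinate table):

```python
import math
import torch

offsets = torch.arange(-7., 8.)  # integer relative offsets for an 8-wide window

# 'swin' mode: sign-preserving log2, divided by log2(8) (applied after the *8 normalization above)
swin = torch.sign(offsets) * torch.log2(1.0 + offsets.abs()) / math.log2(8)
# 'cr' mode: unnormalized natural log of the offsets
cr = torch.sign(offsets) * torch.log(1.0 + offsets.abs())

print(swin.abs().max().item())  # 1.0
print(cr.abs().max().item())    # log(8) ~= 2.08
```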

@@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
 
 import torch
 from torch import nn
+from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@@ -140,6 +141,8 @@ class MaxxVitCfg:
 
 class Attention2d(nn.Module):
+    fast_attn: Final[bool]
+
     """ multi-head attention for 2D NCHW tensors"""
     def __init__(
             self,
@@ -160,6 +163,7 @@ class Attention2d(nn.Module):
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -175,15 +179,31 @@ class Attention2d(nn.Module):
         else:
             q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
 
-        attn = (q.transpose(-2, -1) @ k) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q.transpose(-1, -2),
+                k.transpose(-1, -2),
+                v.transpose(-1, -2),
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            ).transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            q = q * self.scale
+            attn = q.transpose(-2, -1) @ k
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
 
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -191,6 +211,8 @@ class Attention2d(nn.Module):
 
 class AttentionCl(nn.Module):
     """ Channels-last multi-head attention (B, ..., C) """
+    fast_attn: Final[bool]
+
     def __init__(
             self,
             dim: int,
@@ -211,6 +233,7 @@ class AttentionCl(nn.Module):
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -227,15 +250,30 @@ class AttentionCl(nn.Module):
         else:
             q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
+        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
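
Both attention modules above use the same dispatch pattern. A self-contained sketch (hypothetical `attend` helper, not part of the diff) of how the fused path and the explicit-softmax fallback relate:

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, attn_bias=None, dropout_p=0.):
    # q, k, v: (B, num_heads, N, head_dim); attn_bias: additive bias such as rel-pos, or None
    if hasattr(F, 'scaled_dot_product_attention'):
        # fused kernel on newer PyTorch; bias is passed as an additive attn_mask
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p)
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return attn @ v

q = k = v = torch.randn(2, 4, 49, 32)
print(attend(q, k, v).shape)  # torch.Size([2, 4, 49, 32])
```

The modules in the diff do the `hasattr` check once in `__init__` and store it in the `fast_attn: Final[bool]` attribute, so torchscript can treat the branch as a constant rather than re-checking per forward call.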

@@ -33,11 +33,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed
+    resample_abs_pos_embed, RmsNorm
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._pretrained import generate_default_cfgs
@@ -51,28 +52,51 @@ _logger = logging.getLogger(__name__)
 
 class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    fast_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            qk_norm=False,
+            attn_drop=0.,
+            proj_drop=0.,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -96,6 +120,7 @@ class Block(nn.Module):
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             drop=0.,
             attn_drop=0.,
             init_values=None,
@@ -105,13 +130,25 @@ class Block(nn.Module):
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -129,6 +166,7 @@ class ResPostBlock(nn.Module):
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             drop=0.,
             attn_drop=0.,
             init_values=None,
@@ -139,11 +177,24 @@ class ResPostBlock(nn.Module):
         super().__init__()
         self.init_values = init_values
 
-        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -161,8 +212,105 @@ class ResPostBlock(nn.Module):
         return x
 
 
-class ParallelBlock(nn.Module):
+class ParallelScalingBlock(nn.Module):
+    """ Parallel ViT block (MLP & Attention in parallel)
+    Based on:
+      'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
+    """
+    fast_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            drop=0.,
+            attn_drop=0.,
+            init_values=None,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm
+    ):
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        mlp_hidden_dim = int(mlp_ratio * dim)
+        in_proj_out_dim = mlp_hidden_dim + 3 * dim
+
+        self.in_norm = norm_layer(dim)
+        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
+        self.in_split = [mlp_hidden_dim] + [dim] * 3
+        if qkv_bias:
+            self.register_buffer('qkv_bias', None)
+            self.register_parameter('mlp_bias', None)
+        else:
+            self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
+            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))
+
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_out_proj = nn.Linear(dim, dim)
+
+        self.mlp_drop = nn.Dropout(drop)
+        self.mlp_act = act_layer()
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+
+        self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        # Combined MLP fc1 & qkv projections
+        y = self.in_norm(x)
+        if self.mlp_bias is not None:
+            # Concat constant zero-bias for qkv w/ trainable mlp_bias.
+            # Appears faster than adding to x_mlp separately
+            y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
+        else:
+            y = self.in_proj(y)
+        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
+
+        # Dot product attention w/ qk norm
+        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
+        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
+        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+        if self.fast_attn:
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x_attn = attn @ v
+        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
+        x_attn = self.attn_out_proj(x_attn)
+
+        # MLP activation, dropout, fc2
+        x_mlp = self.mlp_act(x_mlp)
+        x_mlp = self.mlp_drop(x_mlp)
+        x_mlp = self.mlp_out_proj(x_mlp)
+
+        # Add residual w/ drop path & layer scale applied
+        y = self.drop_path(self.ls(x_attn + x_mlp))
+        x = x + y
+        return x
+
+
+class ParallelThingsBlock(nn.Module):
+    """ Parallel ViT block (N parallel attention followed by N parallel MLP)
+    Based on:
+      `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    """
 
     def __init__(
             self,
             dim,
@@ -170,6 +318,7 @@ class ParallelBlock(nn.Module):
             num_parallel=2,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             init_values=None,
             drop=0.,
             attn_drop=0.,
@@ -184,13 +333,26 @@ class ParallelBlock(nn.Module):
         for _ in range(num_parallel):
             self.attns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('attn', Attention(
+                    dim,
+                    num_heads=num_heads,
+                    qkv_bias=qkv_bias,
+                    qk_norm=qk_norm,
+                    attn_drop=attn_drop,
+                    proj_drop=drop,
+                    norm_layer=norm_layer,
+                )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                 ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('mlp', Mlp(
+                    dim,
+                    hidden_features=int(dim * mlp_ratio),
+                    act_layer=act_layer,
+                    drop=drop,
+                )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                 ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
             ])))
@@ -232,6 +394,7 @@ class VisionTransformer(nn.Module):
             num_heads=12,
             mlp_ratio=4.,
             qkv_bias=True,
+            qk_norm=False,
             init_values=None,
             class_token=True,
             no_embed_class=False,
@@ -305,6 +468,7 @@ class VisionTransformer(nn.Module):
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
                 init_values=init_values,
                 drop=drop_rate,
                 attn_drop=attn_drop_rate,
@@ -641,9 +805,8 @@ def checkpoint_filter_fn(
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -1129,6 +1292,10 @@ default_cfgs = generate_default_cfgs({
         url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
         hf_hub_id='timm/',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+
+    'vit_base_patch16_xp_224.untrained': _cfg(url=''),
+    'vit_large_patch14_xp_224.untrained': _cfg(url=''),
+    'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
 })
@@ -1566,7 +1733,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
     model = _create_vision_transformer(
         'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@@ -1577,7 +1744,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
     """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
     Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
     model = _create_vision_transformer(
         'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@@ -1625,3 +1793,42 @@ def flexivit_large(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
     model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
+
+
+@register_model
+def vit_base_patch16_xp_224(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/16) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_xp_224(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
+    """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
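
A hypothetical usage sketch, assuming a timm build containing this branch: the vit_*_xp_224 entrypoints registered above wire together ParallelScalingBlock, RmsNorm and qk_norm (no pretrained weights, per the `.untrained` cfgs):

```python
import timm
import torch

model = timm.create_model('vit_base_patch16_xp_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])

# qk_norm can also be switched on for a standard ViT config, since the kwarg
# now flows through VisionTransformer -> Block -> Attention
model_qk = timm.create_model('vit_small_patch16_224', pretrained=False, qk_norm=True)
```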

@@ -11,6 +11,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
+from torch.jit import Final
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -25,14 +26,29 @@ _logger = logging.getLogger(__name__)
 
 class RelPosAttention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
+    fast_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            attn_drop=0.,
+            proj_drop=0.,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
@@ -40,18 +56,35 @@ class RelPosAttention(nn.Module):
 
     def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -70,18 +103,42 @@ class LayerScale(nn.Module):
 
 class RelPosBlock(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
-            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            init_values=None,
+            drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = RelPosAttention(
-            dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
+            dim,
+            num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            rel_pos_cls=rel_pos_cls,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -94,17 +151,41 @@ class RelPosBlock(nn.Module):
 
 class ResPostRelPosBlock(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
-            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            init_values=None,
+            drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         self.init_values = init_values
 
         self.attn = RelPosAttention(
-            dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
+            dim,
+            num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            rel_pos_cls=rel_pos_cls,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -144,6 +225,7 @@ class VisionTransformerRelPos(nn.Module):
             num_heads=12,
             mlp_ratio=4.,
             qkv_bias=True,
+            qk_norm=False,
             init_values=1e-6,
             class_token=False,
             fc_norm=False,
@@ -171,6 +253,7 @@ class VisionTransformerRelPos(nn.Module):
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
+            qk_norm (bool): Enable normalization of query and key in attention
             init_values: (float): layer-scale init values
             class_token (bool): use class token (default: False)
             fc_norm (bool): use pre classifier norm instead of pre-pool
@@ -197,18 +280,19 @@ class VisionTransformerRelPos(nn.Module):
         self.grad_checkpointing = False
 
         self.patch_embed = embed_layer(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
         feat_size = self.patch_embed.grid_size
 
         rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
         if rel_pos_type.startswith('mlp'):
             if rel_pos_dim:
                 rel_pos_args['hidden_dim'] = rel_pos_dim
-            # FIXME experimenting with different relpos log coord configs
             if 'swin' in rel_pos_type:
                 rel_pos_args['mode'] = 'swin'
-            elif 'rw' in rel_pos_type:
-                rel_pos_args['mode'] = 'rw'
             rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
         else:
             rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@@ -223,9 +307,19 @@ class VisionTransformerRelPos(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
             block_fn(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
-                init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
-                norm_layer=norm_layer, act_layer=act_layer)
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                rel_pos_cls=rel_pos_cls,
+                init_values=init_values,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            )
             for i in range(depth)])
 
         self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()
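
A small sketch of what the threaded-through qk_norm option does inside the attention modules above: query and key are normalized per head_dim (default nn.LayerNorm) before the dot product:

```python
import torch
import torch.nn as nn

head_dim, num_heads = 64, 8
q_norm = nn.LayerNorm(head_dim)    # what norm_layer(self.head_dim) builds when qk_norm=True
k_norm = nn.LayerNorm(head_dim)

q = torch.randn(2, num_heads, 196, head_dim)   # (B, num_heads, N, head_dim)
k = torch.randn(2, num_heads, 196, head_dim)
q, k = q_norm(q), k_norm(k)                    # normalize over the last (head_dim) axis
attn_logits = (q * head_dim ** -0.5) @ k.transpose(-2, -1)
print(attn_logits.shape)                       # torch.Size([2, 8, 196, 196])
```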
