|
|
@ -33,11 +33,12 @@ import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
|
|
|
|
from torch.jit import Final
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
|
|
|
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
|
|
|
|
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
|
|
|
|
resample_abs_pos_embed
|
|
|
|
resample_abs_pos_embed, RmsNorm
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
|
|
|
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
@ -51,28 +52,51 @@ _logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
|
|
|
|
class Attention(nn.Module):
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
|
|
|
fast_attn: Final[bool]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads=8,
|
|
|
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
|
|
|
|
attn_drop=0.,
|
|
|
|
|
|
|
|
proj_drop=0.,
|
|
|
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
|
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
|
|
self.num_heads = num_heads
|
|
|
|
self.num_heads = num_heads
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
self.head_dim = dim // num_heads
|
|
|
|
self.scale = head_dim ** -0.5
|
|
|
|
self.scale = self.head_dim ** -0.5
|
|
|
|
|
|
|
|
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
|
|
|
|
|
|
|
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
|
|
|
|
|
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
|
|
|
|
|
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
B, N, C = x.shape
|
|
|
|
B, N, C = x.shape
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
|
|
|
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
|
|
|
q, k, v = qkv.unbind(0)
|
|
|
|
|
|
|
|
q, k = self.q_norm(q), self.k_norm(k)
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
|
|
|
|
|
|
|
|
if self.fast_attn:
|
|
|
|
|
|
|
|
x = F.scaled_dot_product_attention(
|
|
|
|
|
|
|
|
q, k, v,
|
|
|
|
|
|
|
|
dropout_p=self.attn_drop.p,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
q = q * self.scale
|
|
|
|
|
|
|
|
attn = q @ k.transpose(-2, -1)
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
|
|
|
|
x = attn @ v
|
|
|
|
|
|
|
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
|
|
x = x.transpose(1, 2).reshape(B, N, C)
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj_drop(x)
|
|
|
|
x = self.proj_drop(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
@ -96,6 +120,7 @@ class Block(nn.Module):
|
|
|
|
num_heads,
|
|
|
|
num_heads,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
qkv_bias=False,
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
drop=0.,
|
|
|
|
drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
|
init_values=None,
|
|
|
|
init_values=None,
|
|
|
@ -105,13 +130,25 @@ class Block(nn.Module):
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
|
|
|
self.attn = Attention(
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
qk_norm=qk_norm,
|
|
|
|
|
|
|
|
attn_drop=attn_drop,
|
|
|
|
|
|
|
|
proj_drop=drop,
|
|
|
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
|
|
|
)
|
|
|
|
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
|
|
|
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
|
|
|
in_features=dim,
|
|
|
|
|
|
|
|
hidden_features=int(dim * mlp_ratio),
|
|
|
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
|
|
|
drop=drop,
|
|
|
|
|
|
|
|
)
|
|
|
|
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
@ -129,6 +166,7 @@ class ResPostBlock(nn.Module):
|
|
|
|
num_heads,
|
|
|
|
num_heads,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
qkv_bias=False,
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
drop=0.,
|
|
|
|
drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
|
init_values=None,
|
|
|
|
init_values=None,
|
|
|
@ -139,11 +177,24 @@ class ResPostBlock(nn.Module):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.init_values = init_values
|
|
|
|
self.init_values = init_values
|
|
|
|
|
|
|
|
|
|
|
|
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
|
|
|
self.attn = Attention(
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
qk_norm=qk_norm,
|
|
|
|
|
|
|
|
attn_drop=attn_drop,
|
|
|
|
|
|
|
|
proj_drop=drop,
|
|
|
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
|
|
|
)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
|
|
|
in_features=dim,
|
|
|
|
|
|
|
|
hidden_features=int(dim * mlp_ratio),
|
|
|
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
|
|
|
drop=drop,
|
|
|
|
|
|
|
|
)
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
@ -161,8 +212,105 @@ class ResPostBlock(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelBlock(nn.Module):
|
|
|
|
class ParallelScalingBlock(nn.Module):
|
|
|
|
|
|
|
|
""" Parallel ViT block (MLP & Attention in parallel)
|
|
|
|
|
|
|
|
Based on:
|
|
|
|
|
|
|
|
'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
fast_attn: Final[bool]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads,
|
|
|
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
|
|
|
|
drop=0.,
|
|
|
|
|
|
|
|
attn_drop=0.,
|
|
|
|
|
|
|
|
init_values=None,
|
|
|
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
|
|
|
norm_layer=nn.LayerNorm
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
|
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
|
|
|
self.head_dim = dim // num_heads
|
|
|
|
|
|
|
|
self.scale = self.head_dim ** -0.5
|
|
|
|
|
|
|
|
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
|
|
|
|
|
|
|
|
mlp_hidden_dim = int(mlp_ratio * dim)
|
|
|
|
|
|
|
|
in_proj_out_dim = mlp_hidden_dim + 3 * dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.in_norm = norm_layer(dim)
|
|
|
|
|
|
|
|
self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
|
|
|
|
|
|
|
|
self.in_split = [mlp_hidden_dim] + [dim] * 3
|
|
|
|
|
|
|
|
if qkv_bias:
|
|
|
|
|
|
|
|
self.register_buffer('qkv_bias', None)
|
|
|
|
|
|
|
|
self.register_parameter('mlp_bias', None)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
|
|
|
|
|
|
|
|
self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
|
|
|
|
|
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
|
|
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
|
|
|
|
self.attn_out_proj = nn.Linear(dim, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.mlp_drop = nn.Dropout(drop)
|
|
|
|
|
|
|
|
self.mlp_act = act_layer()
|
|
|
|
|
|
|
|
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Combined MLP fc1 & qkv projections
|
|
|
|
|
|
|
|
y = self.in_norm(x)
|
|
|
|
|
|
|
|
if self.mlp_bias is not None:
|
|
|
|
|
|
|
|
# Concat constant zero-bias for qkv w/ trainable mlp_bias.
|
|
|
|
|
|
|
|
# Appears faster than adding to x_mlp separately
|
|
|
|
|
|
|
|
y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
y = self.in_proj(y)
|
|
|
|
|
|
|
|
x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dot product attention w/ qk norm
|
|
|
|
|
|
|
|
q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
|
|
|
|
|
|
|
|
k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
|
|
|
|
|
|
|
|
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
|
|
|
if self.fast_attn:
|
|
|
|
|
|
|
|
x_attn = F.scaled_dot_product_attention(
|
|
|
|
|
|
|
|
q, k, v,
|
|
|
|
|
|
|
|
dropout_p=self.attn_drop.p,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
q = q * self.scale
|
|
|
|
|
|
|
|
attn = q @ k.transpose(-2, -1)
|
|
|
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
|
|
|
|
x_attn = attn @ v
|
|
|
|
|
|
|
|
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
|
|
|
|
|
|
|
|
x_attn = self.attn_out_proj(x_attn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MLP activation, dropout, fc2
|
|
|
|
|
|
|
|
x_mlp = self.mlp_act(x_mlp)
|
|
|
|
|
|
|
|
x_mlp = self.mlp_drop(x_mlp)
|
|
|
|
|
|
|
|
x_mlp = self.mlp_out_proj(x_mlp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Add residual w/ drop path & layer scale applied
|
|
|
|
|
|
|
|
y = self.drop_path(self.ls(x_attn + x_mlp))
|
|
|
|
|
|
|
|
x = x + y
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelThingsBlock(nn.Module):
|
|
|
|
|
|
|
|
""" Parallel ViT block (N parallel attention followed by N parallel MLP)
|
|
|
|
|
|
|
|
Based on:
|
|
|
|
|
|
|
|
`Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
|
|
|
|
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
dim,
|
|
|
|
dim,
|
|
|
@ -170,6 +318,7 @@ class ParallelBlock(nn.Module):
|
|
|
|
num_parallel=2,
|
|
|
|
num_parallel=2,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
qkv_bias=False,
|
|
|
|
qkv_bias=False,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
init_values=None,
|
|
|
|
init_values=None,
|
|
|
|
drop=0.,
|
|
|
|
drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
|
attn_drop=0.,
|
|
|
@ -184,13 +333,26 @@ class ParallelBlock(nn.Module):
|
|
|
|
for _ in range(num_parallel):
|
|
|
|
for _ in range(num_parallel):
|
|
|
|
self.attns.append(nn.Sequential(OrderedDict([
|
|
|
|
self.attns.append(nn.Sequential(OrderedDict([
|
|
|
|
('norm', norm_layer(dim)),
|
|
|
|
('norm', norm_layer(dim)),
|
|
|
|
('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
|
|
|
|
('attn', Attention(
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
qk_norm=qk_norm,
|
|
|
|
|
|
|
|
attn_drop=attn_drop,
|
|
|
|
|
|
|
|
proj_drop=drop,
|
|
|
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
|
|
|
)),
|
|
|
|
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
|
|
|
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
|
|
|
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
|
|
|
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
|
|
|
])))
|
|
|
|
])))
|
|
|
|
self.ffns.append(nn.Sequential(OrderedDict([
|
|
|
|
self.ffns.append(nn.Sequential(OrderedDict([
|
|
|
|
('norm', norm_layer(dim)),
|
|
|
|
('norm', norm_layer(dim)),
|
|
|
|
('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
|
|
|
|
('mlp', Mlp(
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
hidden_features=int(dim * mlp_ratio),
|
|
|
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
|
|
|
drop=drop,
|
|
|
|
|
|
|
|
)),
|
|
|
|
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
|
|
|
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
|
|
|
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
|
|
|
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
|
|
|
])))
|
|
|
|
])))
|
|
|
@ -232,6 +394,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
num_heads=12,
|
|
|
|
num_heads=12,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
mlp_ratio=4.,
|
|
|
|
qkv_bias=True,
|
|
|
|
qkv_bias=True,
|
|
|
|
|
|
|
|
qk_norm=False,
|
|
|
|
init_values=None,
|
|
|
|
init_values=None,
|
|
|
|
class_token=True,
|
|
|
|
class_token=True,
|
|
|
|
no_embed_class=False,
|
|
|
|
no_embed_class=False,
|
|
|
@ -305,6 +468,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
num_heads=num_heads,
|
|
|
|
num_heads=num_heads,
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
qk_norm=qk_norm,
|
|
|
|
init_values=init_values,
|
|
|
|
init_values=init_values,
|
|
|
|
drop=drop_rate,
|
|
|
|
drop=drop_rate,
|
|
|
|
attn_drop=attn_drop_rate,
|
|
|
|
attn_drop=attn_drop_rate,
|
|
|
@ -641,9 +805,8 @@ def checkpoint_filter_fn(
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
if 'model' in state_dict:
|
|
|
|
state_dict = state_dict.get('model', state_dict)
|
|
|
|
# For deit models
|
|
|
|
state_dict = state_dict.get('state_dict', state_dict)
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'visual.class_embedding' in state_dict:
|
|
|
|
if 'visual.class_embedding' in state_dict:
|
|
|
|
return _convert_openai_clip(state_dict, model)
|
|
|
|
return _convert_openai_clip(state_dict, model)
|
|
|
@ -1129,6 +1292,10 @@ default_cfgs = generate_default_cfgs({
|
|
|
|
url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
|
|
|
|
url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
|
|
|
|
hf_hub_id='timm/',
|
|
|
|
hf_hub_id='timm/',
|
|
|
|
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
|
|
|
|
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base_patch16_xp_224.untrained': _cfg(url=''),
|
|
|
|
|
|
|
|
'vit_large_patch14_xp_224.untrained': _cfg(url=''),
|
|
|
|
|
|
|
|
'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1566,7 +1733,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
|
|
|
|
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
|
|
|
|
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(
|
|
|
|
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
|
|
|
|
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
return model
|
|
|
|
return model
|
|
|
@ -1577,7 +1744,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
|
|
|
|
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
|
|
|
|
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
|
|
|
|
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
|
|
|
|
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
return model
|
|
|
|
return model
|
|
|
@ -1625,3 +1793,42 @@ def flexivit_large(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
|
|
|
|
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_patch16_xp_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
|
|
|
|
|
|
|
|
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
|
|
|
|
'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_large_patch14_xp_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
|
|
|
|
|
|
|
|
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
|
|
|
|
'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
|
|
|
|
|
|
|
|
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
|
|
|
|
'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
|
|
|
|
|
|
|
return model
|
|
|
|