Add MViT (Multi-Scale) V2

pull/1415/head
Ross Wightman 2 years ago
parent 43aa84e861
commit 6e559e9b5f

@ -28,6 +28,7 @@ from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .mobilevit import *
from .mvitv2 import *
from .nasnet import *
from .nest import *
from .nfnet import *

@ -0,0 +1,993 @@
""" Multi-Scale Vision Transformer v2
@inproceedings{li2021improved,
title={MViTv2: Improved multiscale vision transformers for classification and detection},
author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
booktitle={CVPR},
year={2022}
}
Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
Original copyright below.
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved.
import operator
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial, reduce
from typing import Union, List, Tuple, Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
**kwargs
}
default_cfgs = dict(
mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
mvitv2_base_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
num_classes=19168),
mvitv2_large_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
num_classes=19168),
mvitv2_huge_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
num_classes=19168),
)
@dataclass
class MultiScaleVitCfg:
depths: Tuple[int, ...] = (2, 3, 16, 3)
embed_dim: Union[int, Tuple[int, ...]] = 96
num_heads: Union[int, Tuple[int, ...]] = 1
mlp_ratio: float = 4.
pool_first: bool = False
expand_attn: bool = True
qkv_bias: bool = True
use_cls_token: bool = False
use_abs_pos: bool = False
residual_pooling: bool = True
mode: str = 'conv'
kernel_qkv: Tuple[int, int] = (3, 3)
stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
stride_kv: Optional[Tuple[Tuple[int, int]]] = None
stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
patch_kernel: Tuple[int, int] = (7, 7)
patch_stride: Tuple[int, int] = (4, 4)
patch_padding: Tuple[int, int] = (3, 3)
pool_type: str = 'max'
rel_pos_type: str = 'spatial'
act_layer: Union[str, Tuple[str, str]] = 'gelu'
norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
norm_eps: float = 1e-6
def __post_init__(self):
num_stages = len(self.depths)
if not isinstance(self.embed_dim, (tuple, list)):
self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
assert len(self.embed_dim) == num_stages
if not isinstance(self.num_heads, (tuple, list)):
self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
assert len(self.num_heads) == num_stages
if self.stride_kv_adaptive is not None and self.stride_kv is None:
_stride_kv = self.stride_kv_adaptive
pool_kv_stride = []
for i in range(num_stages):
if min(self.stride_q[i]) > 1:
_stride_kv = [
max(_stride_kv[d] // self.stride_q[i][d], 1)
for d in range(len(_stride_kv))
]
pool_kv_stride.append(tuple(_stride_kv))
self.stride_kv = tuple(pool_kv_stride)
model_cfgs = dict(
mvitv2_tiny=MultiScaleVitCfg(
depths=(1, 2, 5, 2),
),
mvitv2_small=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
),
mvitv2_base=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
mvitv2_base_in21k=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large_in21k=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
)
def prod(iterable):
return reduce(operator.mul, iterable, 1)
class PatchEmbed(nn.Module):
"""
PatchEmbed.
"""
def __init__(
self,
dim_in=3,
dim_out=768,
kernel=(7, 7),
stride=(4, 4),
padding=(3, 3),
):
super().__init__()
self.proj = nn.Conv2d(
dim_in,
dim_out,
kernel_size=kernel,
stride=stride,
padding=padding,
)
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.proj(x)
# B C H W -> B HW C
return x.flatten(2).transpose(1, 2), x.shape[-2:]
def reshape_pre_pool(
x,
feat_size: List[int],
has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
H, W = feat_size
if has_cls_token:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
else:
cls_tok = None
x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
return x, cls_tok
def reshape_post_pool(
x,
num_heads: int,
cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
feat_size = [x.shape[2], x.shape[3]]
L_pooled = x.shape[2] * x.shape[3]
x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=2)
return x, feat_size
def cal_rel_pos_type(
attn: torch.Tensor,
q: torch.Tensor,
has_cls_token: bool,
q_size: List[int],
k_size: List[int],
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
):
"""
Spatial Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_token else 0
q_h, q_w = q_size
k_h, k_w = k_size
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
dist_w += (k_w - 1) * k_w_ratio
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, :, None]
+ rel_w[:, :, :, :, None, :]
).view(B, -1, q_h * q_w, k_h * k_w)
return attn
class MultiScaleAttentionPoolFirst(nn.Module):
def __init__(
self,
dim,
dim_out,
feat_size,
num_heads=8,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_heads = num_heads
self.dim_out = dim_out
self.head_dim = dim_out // num_heads
self.scale = self.head_dim ** -0.5
self.has_cls_token = has_cls_token
padding_q = tuple([int(q // 2) for q in kernel_q])
padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# Skip pooling with kernel and stride size of (1, 1, 1).
if prod(kernel_q) == 1 and prod(stride_q) == 1:
kernel_q = None
if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
kernel_kv = None
self.mode = mode
self.unshared = mode == 'conv_unshared'
self.pool_q, self.pool_k, self.pool_v = None, None, None
self.norm_q, self.norm_k, self.norm_v = None, None, None
if mode in ("avg", "max"):
pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
if kernel_q:
self.pool_q = pool_op(kernel_q, stride_q, padding_q)
if kernel_kv:
self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
elif mode == "conv" or mode == "conv_unshared":
dim_conv = dim // num_heads if mode == "conv" else dim
if kernel_q:
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
if kernel_kv:
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
else:
raise NotImplementedError(f"Unsupported model {mode}")
# relative pos embedding
self.rel_pos_type = rel_pos_type
if self.rel_pos_type == 'spatial':
assert feat_size[0] == feat_size[1]
size = feat_size[0]
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
rel_sp_dim = 2 * max(q_size, kv_size) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
trunc_normal_tf_(self.rel_pos_h, std=0.02)
trunc_normal_tf_(self.rel_pos_w, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, feat_size: List[int]):
B, N, _ = x.shape
fold_dim = 1 if self.unshared else self.num_heads
x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
q = k = v = x
if self.pool_q is not None:
q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
q = self.pool_q(q)
q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
else:
q_size = feat_size
if self.norm_q is not None:
q = self.norm_q(q)
if self.pool_k is not None:
k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
k = self.pool_k(k)
k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
else:
k_size = feat_size
if self.norm_k is not None:
k = self.norm_k(k)
if self.pool_v is not None:
v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
v = self.pool_v(v)
v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
else:
v_size = feat_size
if self.norm_v is not None:
v = self.norm_v(v)
q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)
k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)
v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
q,
self.has_cls_token,
q_size,
k_size,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_size
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim,
dim_out,
feat_size,
num_heads=8,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_heads = num_heads
self.dim_out = dim_out
self.head_dim = dim_out // num_heads
self.scale = self.head_dim ** -0.5
self.has_cls_token = has_cls_token
padding_q = tuple([int(q // 2) for q in kernel_q])
padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# Skip pooling with kernel and stride size of (1, 1, 1).
if prod(kernel_q) == 1 and prod(stride_q) == 1:
kernel_q = None
if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
kernel_kv = None
self.mode = mode
self.unshared = mode == 'conv_unshared'
self.norm_q, self.norm_k, self.norm_v = None, None, None
self.pool_q, self.pool_k, self.pool_v = None, None, None
if mode in ("avg", "max"):
pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
if kernel_q:
self.pool_q = pool_op(kernel_q, stride_q, padding_q)
if kernel_kv:
self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
elif mode == "conv" or mode == "conv_unshared":
dim_conv = dim_out // num_heads if mode == "conv" else dim_out
if kernel_q:
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
if kernel_kv:
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
else:
raise NotImplementedError(f"Unsupported model {mode}")
# relative pos embedding
self.rel_pos_type = rel_pos_type
if self.rel_pos_type == 'spatial':
assert feat_size[0] == feat_size[1]
size = feat_size[0]
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
rel_sp_dim = 2 * max(q_size, kv_size) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
trunc_normal_tf_(self.rel_pos_h, std=0.02)
trunc_normal_tf_(self.rel_pos_w, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, feat_size: List[int]):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0)
if self.pool_q is not None:
q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
q = self.pool_q(q)
q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
else:
q_size = feat_size
if self.norm_q is not None:
q = self.norm_q(q)
if self.pool_k is not None:
k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
k = self.pool_k(k)
k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
else:
k_size = feat_size
if self.norm_k is not None:
k = self.norm_k(k)
if self.pool_v is not None:
v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
v = self.pool_v(v)
v, _ = reshape_post_pool(v, self.num_heads, v_tok)
if self.norm_v is not None:
v = self.norm_v(v)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
q,
self.has_cls_token,
q_size,
k_size,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_size
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
num_heads,
feat_size,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
mode="conv",
has_cls_token=True,
expand_attn=False,
pool_first=False,
rel_pos_type='spatial',
residual_pooling=True,
):
super().__init__()
proj_needed = dim != dim_out
self.dim = dim
self.dim_out = dim_out
self.has_cls_token = has_cls_token
self.norm1 = norm_layer(dim)
self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None
if stride_q and prod(stride_q) > 1:
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
stride_skip = stride_q
padding_skip = [int(skip // 2) for skip in kernel_skip]
self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
else:
self.shortcut_pool_attn = None
att_dim = dim_out if expand_attn else dim
attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
self.attn = attn_layer(
dim,
att_dim,
num_heads=num_heads,
feat_size=feat_size,
qkv_bias=qkv_bias,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
has_cls_token=has_cls_token,
mode=mode,
rel_pos_type=rel_pos_type,
residual_pooling=residual_pooling,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(att_dim)
mlp_dim_out = dim_out
self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None
self.mlp = Mlp(
in_features=att_dim,
hidden_features=int(att_dim * mlp_ratio),
out_features=mlp_dim_out,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def _shortcut_pool(self, x, feat_size: List[int]):
if self.shortcut_pool_attn is None:
return x
if self.has_cls_token:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
else:
cls_tok = None
B, L, C = x.shape
H, W = feat_size
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
x = self.shortcut_pool_attn(x)
x = x.reshape(B, C, -1).transpose(1, 2)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=2)
return x
def forward(self, x, feat_size: List[int]):
x_norm = self.norm1(x)
# NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
x, feat_size_new = self.attn(x_norm, feat_size)
x = x_shortcut + self.drop_path1(x)
x_norm = self.norm2(x)
x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
x = x_shortcut + self.drop_path2(self.mlp(x_norm))
return x, feat_size_new
class MultiScaleVitStage(nn.Module):
def __init__(
self,
dim,
dim_out,
depth,
num_heads,
feat_size,
mlp_ratio=4.0,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
expand_attn=False,
pool_first=False,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
drop_path=0.0,
):
super().__init__()
self.grad_checkpointing = False
self.blocks = nn.ModuleList()
if expand_attn:
out_dims = (dim_out,) * depth
else:
out_dims = (dim,) * (depth - 1) + (dim_out,)
for i in range(depth):
attention_block = MultiScaleBlock(
dim=dim,
dim_out=out_dims[i],
num_heads=num_heads,
feat_size=feat_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q if i == 0 else (1, 1),
stride_kv=stride_kv,
mode=mode,
has_cls_token=has_cls_token,
pool_first=pool_first,
rel_pos_type=rel_pos_type,
residual_pooling=residual_pooling,
expand_attn=expand_attn,
norm_layer=norm_layer,
drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
)
dim = out_dims[i]
self.blocks.append(attention_block)
if i == 0:
feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])
self.feat_size = feat_size
def forward(self, x, feat_size: List[int]):
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
else:
x, feat_size = blk(x, feat_size)
return x, feat_size
class MultiScaleVit(nn.Module):
"""
Improved Multiscale Vision Transformers for Classification and Detection
Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
Christoph Feichtenhofer*
https://arxiv.org/abs/2112.01526
Multiscale Vision Transformers
Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
Christoph Feichtenhofer*
https://arxiv.org/abs/2104.11227
"""
def __init__(
self,
cfg: MultiScaleVitCfg,
img_size: Tuple[int, int] = (224, 224),
in_chans: int = 3,
global_pool: str = 'avg',
num_classes: int = 1000,
drop_path_rate: float = 0.,
drop_rate: float = 0.,
):
super().__init__()
img_size = to_2tuple(img_size)
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
self.global_pool = global_pool
self.depths = tuple(cfg.depths)
self.expand_attn = cfg.expand_attn
embed_dim = cfg.embed_dim[0]
self.patch_embed = PatchEmbed(
dim_in=in_chans,
dim_out=embed_dim,
kernel=cfg.patch_kernel,
stride=cfg.patch_stride,
padding=cfg.patch_padding,
)
patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
num_patches = prod(patch_dims)
if cfg.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.num_prefix_tokens = 1
pos_embed_dim = num_patches + 1
else:
self.num_prefix_tokens = 0
self.cls_token = None
pos_embed_dim = num_patches
if cfg.use_abs_pos:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim))
else:
self.pos_embed = None
num_stages = len(cfg.embed_dim)
feat_size = patch_dims
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
self.stages = nn.ModuleList()
for i in range(num_stages):
if cfg.expand_attn:
dim_out = cfg.embed_dim[i]
else:
dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
stage = MultiScaleVitStage(
dim=embed_dim,
dim_out=dim_out,
depth=cfg.depths[i],
num_heads=cfg.num_heads[i],
feat_size=feat_size,
mlp_ratio=cfg.mlp_ratio,
qkv_bias=cfg.qkv_bias,
mode=cfg.mode,
pool_first=cfg.pool_first,
expand_attn=cfg.expand_attn,
kernel_q=cfg.kernel_qkv,
kernel_kv=cfg.kernel_qkv,
stride_q=cfg.stride_q[i],
stride_kv=cfg.stride_kv[i],
has_cls_token=cfg.use_cls_token,
rel_pos_type=cfg.rel_pos_type,
residual_pooling=cfg.residual_pooling,
norm_layer=norm_layer,
drop_path=dpr[i],
)
embed_dim = dim_out
feat_size = stage.feat_size
self.stages.append(stage)
self.num_features = embed_dim
self.norm = norm_layer(embed_dim)
self.head = nn.Sequential(OrderedDict([
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
if self.pos_embed is not None:
trunc_normal_tf_(self.pos_embed, std=0.02)
if self.cls_token is not None:
trunc_normal_tf_(self.cls_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_tf_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
@torch.jit.ignore
def no_weight_decay(self):
return {k for k, _ in self.named_parameters()
if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem', # stem and embed
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Sequential(OrderedDict([
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
B, N, C = x.shape
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
for stage in self.stages:
x, feat_size = stage(x, feat_size)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
if self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(1)
else:
x = x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
if 'stages.0.blocks.0.norm1.weight' in state_dict:
return state_dict
import re
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
depths = getattr(model, 'depths', None)
expand_attn = getattr(model, 'expand_attn', True)
assert depths is not None, 'model requires depth attribute to remap checkpoints'
depth_map = {}
block_idx = 0
for stage_idx, d in enumerate(depths):
depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
block_idx += d
out_dict = {}
for k, v in state_dict.items():
k = re.sub(
r'blocks\.(\d+)',
lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
k)
if expand_attn:
k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_attn', k)
else:
k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_mlp', k)
if 'head' in k:
k = k.replace('head.projection', 'head.fc')
out_dict[k] = v
# for k, v in state_dict.items():
# if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# # To resize pos embedding when using model at different size from pretrained weights
# v = resize_pos_embed(
# v,
# model.pos_embed,
# 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
# model.patch_embed.grid_size
# )
return out_dict
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
MultiScaleVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def mvitv2_tiny(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
# @register_model
# def mvitv2_base_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_large_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
Loading…
Cancel
Save