Davit (#1)
Implement the davit model from https://arxiv.org/abs/2204.03645 and https://github.com/dingmyu/davitpull/1583/head
parent
da6644b6ba
commit
3bd96609c8
@ -0,0 +1,624 @@
|
|||||||
|
""" DaViT: Dual Attention Vision Transformers
|
||||||
|
|
||||||
|
As described in https://arxiv.org/abs/2204.03645
|
||||||
|
|
||||||
|
Input size invariant transformer architecture that combines channel and spacial
|
||||||
|
attention in each block. The attention mechanisms used are linear in complexity.
|
||||||
|
|
||||||
|
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Copyright (c) 2022 Mingyu Ding
|
||||||
|
# All rights reserved.
|
||||||
|
# This source code is licensed under the MIT license
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .helpers import build_model_with_cfg
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
||||||
|
from collections import OrderedDict
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from .pretrained import generate_default_cfgs
|
||||||
|
from .registry import register_model
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['DaViT']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MySequential(nn.Sequential):
|
||||||
|
def forward(self, *inputs):
|
||||||
|
for module in self._modules.values():
|
||||||
|
if type(inputs) == tuple:
|
||||||
|
inputs = module(*inputs)
|
||||||
|
else:
|
||||||
|
inputs = module(inputs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
class ConvPosEnc(nn.Module):
|
||||||
|
def __init__(self, dim, k=3, act=False, normtype=False):
|
||||||
|
super(ConvPosEnc, self).__init__()
|
||||||
|
self.proj = nn.Conv2d(dim,
|
||||||
|
dim,
|
||||||
|
to_2tuple(k),
|
||||||
|
to_2tuple(1),
|
||||||
|
to_2tuple(k // 2),
|
||||||
|
groups=dim)
|
||||||
|
self.normtype = normtype
|
||||||
|
if self.normtype == 'batch':
|
||||||
|
self.norm = nn.BatchNorm2d(dim)
|
||||||
|
elif self.normtype == 'layer':
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.activation = nn.GELU() if act else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, size: Tuple[int, int]):
|
||||||
|
B, N, C = x.shape
|
||||||
|
H, W = size
|
||||||
|
assert N == H * W
|
||||||
|
|
||||||
|
feat = x.transpose(1, 2).view(B, C, H, W)
|
||||||
|
feat = self.proj(feat)
|
||||||
|
if self.normtype == 'batch':
|
||||||
|
feat = self.norm(feat).flatten(2).transpose(1, 2)
|
||||||
|
elif self.normtype == 'layer':
|
||||||
|
feat = self.norm(feat.flatten(2).transpose(1, 2))
|
||||||
|
else:
|
||||||
|
feat = feat.flatten(2).transpose(1, 2)
|
||||||
|
x = x + self.activation(feat)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
""" Size-agnostic implementation of 2D image to patch embedding,
|
||||||
|
allowing input size to be adjusted during model forward operation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size=16,
|
||||||
|
in_chans=3,
|
||||||
|
embed_dim=96,
|
||||||
|
overlapped=False):
|
||||||
|
super().__init__()
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
if patch_size[0] == 4:
|
||||||
|
self.proj = nn.Conv2d(
|
||||||
|
in_chans,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=(7, 7),
|
||||||
|
stride=patch_size,
|
||||||
|
padding=(3, 3))
|
||||||
|
self.norm = nn.LayerNorm(embed_dim)
|
||||||
|
if patch_size[0] == 2:
|
||||||
|
kernel = 3 if overlapped else 2
|
||||||
|
pad = 1 if overlapped else 0
|
||||||
|
self.proj = nn.Conv2d(
|
||||||
|
in_chans,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=to_2tuple(kernel),
|
||||||
|
stride=patch_size,
|
||||||
|
padding=to_2tuple(pad))
|
||||||
|
self.norm = nn.LayerNorm(in_chans)
|
||||||
|
|
||||||
|
def forward(self, x, size):
|
||||||
|
H, W = size
|
||||||
|
dim = len(x.shape)
|
||||||
|
if dim == 3:
|
||||||
|
B, HW, C = x.shape
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.reshape(B,
|
||||||
|
H,
|
||||||
|
W,
|
||||||
|
C).permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
if W % self.patch_size[1] != 0:
|
||||||
|
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||||
|
if H % self.patch_size[0] != 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||||
|
|
||||||
|
x = self.proj(x)
|
||||||
|
newsize = (x.size(2), x.size(3))
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
if dim == 4:
|
||||||
|
x = self.norm(x)
|
||||||
|
return x, newsize
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
k = k * self.scale
|
||||||
|
attention = k.transpose(-1, -2) @ v
|
||||||
|
attention = attention.softmax(dim=-1)
|
||||||
|
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
|
||||||
|
x = x.transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
||||||
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||||
|
ffn=True, cpe_act=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
|
||||||
|
ConvPosEnc(dim=dim, k=3, act=cpe_act)])
|
||||||
|
self.ffn = ffn
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
|
||||||
|
if self.ffn:
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer)
|
||||||
|
|
||||||
|
def forward(self, x, size):
|
||||||
|
x = self.cpe[0](x, size)
|
||||||
|
cur = self.norm1(x)
|
||||||
|
cur = self.attn(cur)
|
||||||
|
x = x + self.drop_path(cur)
|
||||||
|
|
||||||
|
x = self.cpe[1](x, size)
|
||||||
|
if self.ffn:
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
return x, size
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
window_size (int): window size
|
||||||
|
Returns:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
"""
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||||
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size: int, H: int, W: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
window_size (int): Window size
|
||||||
|
H (int): Height of image
|
||||||
|
W (int): Width of image
|
||||||
|
Returns:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
"""
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WindowAttention(nn.Module):
|
||||||
|
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||||
|
It supports both of shifted and non-shifted window.
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
window_size (tuple[int]): The height and width of the window.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.window_size = window_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B_, N, C = x.shape
|
||||||
|
|
||||||
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialBlock(nn.Module):
|
||||||
|
r""" Windows Block.
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
window_size (int): Window size.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||||
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, window_size=7,
|
||||||
|
mlp_ratio=4., qkv_bias=True, drop_path=0.,
|
||||||
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||||
|
ffn=True, cpe_act=False):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.ffn = ffn
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
|
||||||
|
ConvPosEnc(dim=dim, k=3, act=cpe_act)])
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = WindowAttention(
|
||||||
|
dim,
|
||||||
|
window_size=to_2tuple(self.window_size),
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias)
|
||||||
|
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
|
||||||
|
if self.ffn:
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer)
|
||||||
|
|
||||||
|
def forward(self, x, size):
|
||||||
|
H, W = size
|
||||||
|
B, L, C = x.shape
|
||||||
|
assert L == H * W, "input feature has wrong size"
|
||||||
|
|
||||||
|
shortcut = self.cpe[0](x, size)
|
||||||
|
x = self.norm1(shortcut)
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
pad_l = pad_t = 0
|
||||||
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||||
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||||
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||||
|
_, Hp, Wp, _ = x.shape
|
||||||
|
|
||||||
|
x_windows = window_partition(x, self.window_size)
|
||||||
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||||
|
|
||||||
|
# W-MSA/SW-MSA
|
||||||
|
attn_windows = self.attn(x_windows)
|
||||||
|
|
||||||
|
# merge windows
|
||||||
|
attn_windows = attn_windows.view(-1,
|
||||||
|
self.window_size,
|
||||||
|
self.window_size,
|
||||||
|
C)
|
||||||
|
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
|
||||||
|
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = x[:, :H, :W, :].contiguous()
|
||||||
|
|
||||||
|
x = x.view(B, H * W, C)
|
||||||
|
x = shortcut + self.drop_path(x)
|
||||||
|
|
||||||
|
x = self.cpe[1](x, size)
|
||||||
|
if self.ffn:
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
return x, size
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DaViT(nn.Module):
|
||||||
|
r""" Dual Attention Transformer
|
||||||
|
Args:
|
||||||
|
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||||
|
in_chans (int): Number of input image channels. Default: 3
|
||||||
|
embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
|
||||||
|
num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
|
||||||
|
window_size (int): Window size. Default: 7
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||||
|
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||||
|
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_chans=3,
|
||||||
|
depths=(1, 1, 3, 1),
|
||||||
|
patch_size=4,
|
||||||
|
embed_dims=(96, 192, 384, 768),
|
||||||
|
num_heads=(3, 6, 12, 24),
|
||||||
|
window_size=7,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=True,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
attention_types=('spatial', 'channel'),
|
||||||
|
ffn=True,
|
||||||
|
overlapped_patch=False,
|
||||||
|
cpe_act=False,
|
||||||
|
drop_rate=0.,
|
||||||
|
attn_drop_rate=0.,
|
||||||
|
img_size=224,
|
||||||
|
num_classes=1000,
|
||||||
|
global_pool='avg'
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
architecture = [[index] * item for index, item in enumerate(depths)]
|
||||||
|
self.architecture = architecture
|
||||||
|
self.embed_dims = embed_dims
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_stages = len(self.embed_dims)
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
|
||||||
|
assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.num_features = embed_dims[-1]
|
||||||
|
self.drop_rate=drop_rate
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
|
|
||||||
|
self.patch_embeds = nn.ModuleList([
|
||||||
|
PatchEmbed(patch_size=patch_size if i == 0 else 2,
|
||||||
|
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
|
||||||
|
embed_dim=self.embed_dims[i],
|
||||||
|
overlapped=overlapped_patch)
|
||||||
|
for i in range(self.num_stages)])
|
||||||
|
|
||||||
|
main_blocks = []
|
||||||
|
for block_id, block_param in enumerate(self.architecture):
|
||||||
|
layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
|
||||||
|
|
||||||
|
block = nn.ModuleList([
|
||||||
|
MySequential(*[
|
||||||
|
ChannelBlock(
|
||||||
|
dim=self.embed_dims[item],
|
||||||
|
num_heads=self.num_heads[item],
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
ffn=ffn,
|
||||||
|
cpe_act=cpe_act
|
||||||
|
) if attention_type == 'channel' else
|
||||||
|
SpatialBlock(
|
||||||
|
dim=self.embed_dims[item],
|
||||||
|
num_heads=self.num_heads[item],
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
ffn=ffn,
|
||||||
|
cpe_act=cpe_act,
|
||||||
|
window_size=window_size,
|
||||||
|
) if attention_type == 'spatial' else None
|
||||||
|
for attention_id, attention_type in enumerate(attention_types)]
|
||||||
|
) for layer_id, item in enumerate(block_param)
|
||||||
|
])
|
||||||
|
main_blocks.append(block)
|
||||||
|
self.main_blocks = nn.ModuleList(main_blocks)
|
||||||
|
|
||||||
|
'''
|
||||||
|
# layer norms for pyramid feature extraction
|
||||||
|
#
|
||||||
|
# TODO implement pyramid feature extraction
|
||||||
|
#
|
||||||
|
# davit should be a good transformer candidate, since the only official implementation
|
||||||
|
# is for segmentation and detection
|
||||||
|
for i_layer in range(self.num_stages):
|
||||||
|
layer = norm_layer(self.embed_dims[i_layer])
|
||||||
|
layer_name = f'norm{i_layer}'
|
||||||
|
self.add_module(layer_name, layer)
|
||||||
|
'''
|
||||||
|
self.norms = norm_layer(self.num_features)
|
||||||
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.head.fc
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool=None):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
if global_pool is None:
|
||||||
|
global_pool = self.head.global_pool.pool_type
|
||||||
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def forward_features_full(self, x):
|
||||||
|
x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
|
||||||
|
features = [x]
|
||||||
|
sizes = [size]
|
||||||
|
branches = [0]
|
||||||
|
|
||||||
|
for block_index, block_param in enumerate(self.architecture):
|
||||||
|
branch_ids = sorted(set(block_param))
|
||||||
|
for branch_id in branch_ids:
|
||||||
|
if branch_id not in branches:
|
||||||
|
x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
|
||||||
|
features.append(x)
|
||||||
|
sizes.append(size)
|
||||||
|
branches.append(branch_id)
|
||||||
|
for layer_index, branch_id in enumerate(block_param):
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
|
||||||
|
else:
|
||||||
|
features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
|
||||||
|
'''
|
||||||
|
# pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
|
||||||
|
outs = []
|
||||||
|
for i in range(self.num_stages):
|
||||||
|
norm_layer = getattr(self, f'norm{i}')
|
||||||
|
x_out = norm_layer(features[i])
|
||||||
|
H, W = sizes[i]
|
||||||
|
out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
|
||||||
|
outs.append(out)
|
||||||
|
'''
|
||||||
|
# non-normalized pyramid features + corresponding sizes
|
||||||
|
return tuple(features), tuple(sizes)
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x, sizes = self.forward_features_full(x)
|
||||||
|
# take final feature and norm
|
||||||
|
x = self.norms(x[-1])
|
||||||
|
H, W = sizes[-1]
|
||||||
|
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
||||||
|
#print(x.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward_head(self, x, pre_logits: bool = False):
|
||||||
|
|
||||||
|
return self.head(x, pre_logits=pre_logits)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
x = self.forward_head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def checkpoint_filter_fn(state_dict, model):
|
||||||
|
""" Remap MSFT checkpoints -> timm """
|
||||||
|
if 'head.norm.weight' in state_dict:
|
||||||
|
return state_dict # non-MSFT checkpoint
|
||||||
|
|
||||||
|
if 'state_dict' in state_dict:
|
||||||
|
state_dict = state_dict['state_dict']
|
||||||
|
|
||||||
|
out_dict = {}
|
||||||
|
import re
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
|
||||||
|
k = k.replace('head.', 'head.fc.')
|
||||||
|
out_dict[k] = v
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _create_davit(variant, pretrained=False, **kwargs):
|
||||||
|
model = build_model_with_cfg(DaViT, variant, pretrained,
|
||||||
|
pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = generate_default_cfgs({
|
||||||
|
|
||||||
|
'davit_tiny.msft_in1k': _cfg(
|
||||||
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
||||||
|
'davit_small.msft_in1k': _cfg(
|
||||||
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
||||||
|
'davit_base.msft_in1k': _cfg(
|
||||||
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def davit_tiny(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
||||||
|
num_heads=(3, 6, 12, 24), **kwargs)
|
||||||
|
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def davit_small(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
|
||||||
|
num_heads=(3, 6, 12, 24), **kwargs)
|
||||||
|
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def davit_base(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
|
||||||
|
num_heads=(4, 8, 16, 32), **kwargs)
|
||||||
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
''' models without weights
|
||||||
|
# TODO contact authors to get larger pretrained models
|
||||||
|
@register_model
|
||||||
|
def davit_large(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
|
||||||
|
num_heads=(6, 12, 24, 48), **kwargs)
|
||||||
|
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def davit_huge(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
|
||||||
|
num_heads=(8, 16, 32, 64), **kwargs)
|
||||||
|
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def davit_giant(pretrained=False, **kwargs):
|
||||||
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
||||||
|
num_heads=(12, 24, 48, 96), **kwargs)
|
||||||
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
||||||
|
'''
|
Loading…
Reference in new issue