Finalize DaViT: formatting and modelling simplifications (separate PatchEmbed into Stem + Downsample), weights on HF hub.

pull/1654/head
Ross Wightman 2 years ago
parent fb717056da
commit 9a53c3f727

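Note: with the default cfgs now pointing at hf_hub_id='timm/', the converted MSFT weights resolve from the Hugging Face hub instead of the old GitHub release URLs. A minimal usage sketch (an assumption-level example, requiring a timm build that includes this commit and the published timm/davit_*.msft_in1k hub weights):

import torch
import timm

# 'msft_in1k' is the pretrained tag defined in default_cfgs in this diff
model = timm.create_model('davit_tiny.msft_in1k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # default cfg input_size is (3, 224, 224)
print(logits.shape)  # torch.Size([1, 1000])
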
@@ -29,7 +29,6 @@ NON_STD_FILTERS = [
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
     'eva_*', 'flexivit*'
 ]
-#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', '
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures

@@ -11,9 +11,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # Copyright (c) 2022 Mingyu Ding
 # All rights reserved.
 # This source code is licensed under the MIT license
-from collections import OrderedDict
 import itertools
+from collections import OrderedDict
+from functools import partial
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -21,9 +22,8 @@ import torch.nn.functional as F
 from torch import Tensor
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp # ClassifierHead
+from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
 from ._builder import build_model_with_cfg
-from ._features import FeatureInfo
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
@@ -33,89 +33,83 @@ __all__ = ['DaViT']
 
 class ConvPosEnc(nn.Module):
-    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
+    def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()
-        self.proj = nn.Conv2d(dim,
-                              dim,
-                              to_2tuple(k),
-                              to_2tuple(1),
-                              to_2tuple(k // 2),
-                              groups=dim)
-        self.normtype = normtype
-        self.norm = nn.Identity()
-        if self.normtype == 'batch':
-            self.norm = nn.BatchNorm2d(dim)
-        elif self.normtype == 'layer':
-            self.norm = nn.LayerNorm(dim)
-        self.activation = nn.GELU() if act else nn.Identity()
+        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.act = nn.GELU() if act else nn.Identity()
 
-    def forward(self, x : Tensor):
-        B, C, H, W = x.shape
-        #feat = x.transpose(1, 2).view(B, C, H, W)
-        feat = self.proj(x)
-        if self.normtype == 'batch':
-            feat = self.norm(feat).flatten(2).transpose(1, 2)
-        elif self.normtype == 'layer':
-            feat = self.norm(feat.flatten(2).transpose(1, 2))
-        else:
-            feat = feat.flatten(2).transpose(1, 2)
-        x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
+    def forward(self, x: Tensor):
+        feat = self.proj(x)
+        x = x + self.act(feat)
         return x
 
 
-class PatchEmbed(nn.Module):
+class Stem(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
    """
 
    def __init__(
            self,
-            patch_size=4,
-            in_chans=3,
-            embed_dim=96,
-            overlapped=False):
+            in_chs=3,
+            out_chs=96,
+            stride=4,
+            norm_layer=LayerNorm2d,
+    ):
        super().__init__()
-        patch_size = to_2tuple(patch_size)
-        self.patch_size = patch_size
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
+        stride = to_2tuple(stride)
+        self.stride = stride
+        self.in_chs = in_chs
+        self.out_chs = out_chs
+        assert stride[0] == 4  # only setup for stride==4
+        self.conv = nn.Conv2d(
+            in_chs,
+            out_chs,
+            kernel_size=7,
+            stride=stride,
+            padding=3,
+        )
+        self.norm = norm_layer(out_chs)
 
-        if patch_size[0] == 4:
-            self.proj = nn.Conv2d(
-                in_chans,
-                embed_dim,
-                kernel_size=(7, 7),
-                stride=patch_size,
-                padding=(3, 3))
-            self.norm = nn.LayerNorm(embed_dim)
-        if patch_size[0] == 2:
-            kernel = 3 if overlapped else 2
-            pad = 1 if overlapped else 0
-            self.proj = nn.Conv2d(
-                in_chans,
-                embed_dim,
-                kernel_size=to_2tuple(kernel),
-                stride=patch_size,
-                padding=to_2tuple(pad))
-            self.norm = nn.LayerNorm(in_chans)
+    def forward(self, x: Tensor):
+        B, C, H, W = x.shape
+        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
+        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
 
-    def forward(self, x: Tensor):
-        B, C, H, W = x.shape
-        if self.norm.normalized_shape[0] == self.in_chans:
-            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
-        x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
-        x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))
-        x = self.proj(x)
+class Downsample(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            norm_layer=LayerNorm2d,
+    ):
+        super().__init__()
+        self.in_chs = in_chs
+        self.out_chs = out_chs
+        self.norm = norm_layer(in_chs)
+        self.conv = nn.Conv2d(
+            in_chs,
+            out_chs,
+            kernel_size=2,
+            stride=2,
+            padding=0,
+        )
 
-        if self.norm.normalized_shape[0] == self.embed_dim:
-            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+    def forward(self, x: Tensor):
+        B, C, H, W = x.shape
+        x = self.norm(x)
+        x = F.pad(x, (0, (2 - W % 2) % 2))
+        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        x = self.conv(x)
         return x
 
 
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
@@ -131,7 +125,7 @@ class ChannelAttention(nn.Module):
         B, N, C = x.shape
 
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        q, k, v = qkv.unbind(0)
 
         k = k * self.scale
         attention = k.transpose(-1, -2) @ v
@@ -144,46 +138,60 @@ class ChannelAttention(nn.Module):
 
 class ChannelBlock(nn.Module):
 
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
-                 ffn=True, cpe_act=False):
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            ffn=True,
+            cpe_act=False,
+    ):
         super().__init__()
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
         self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
-            mlp_hidden_dim = int(dim * mlp_ratio)
             self.mlp = Mlp(
                 in_features=dim,
-                hidden_features=mlp_hidden_dim,
-                act_layer=act_layer)
+                hidden_features=int(dim * mlp_ratio),
+                act_layer=act_layer,
+            )
+            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        else:
+            self.norm2 = None
+            self.mlp = None
+            self.drop_path2 = None
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
 
         x = self.cpe1(x).flatten(2).transpose(1, 2)
         cur = self.norm1(x)
         cur = self.attn(cur)
-        x = x + self.drop_path(cur)
+        x = x + self.drop_path1(cur)
 
-        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
-        if self.ffn:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
+        if self.mlp is not None:
+            x = x.flatten(2).transpose(1, 2)
+            x = x + self.drop_path2(self.mlp(self.norm2(x)))
         x = x.transpose(1, 2).view(B, C, H, W)
         return x
 
 
-def window_partition(x : Tensor, window_size: int):
+def window_partition(x: Tensor, window_size: Tuple[int, int]):
     """
     Args:
         x: (B, H, W, C)
@@ -192,12 +200,13 @@ def window_partition(x : Tensor, window_size: int):
         windows: (num_windows*B, window_size, window_size, C)
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
+def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -207,9 +216,8 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
     Returns:
         x: (B, H, W, C)
     """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x
@@ -225,7 +233,6 @@ class WindowAttention(nn.Module):
     """
 
     def __init__(self, dim, window_size, num_heads, qkv_bias=True):
         super().__init__()
-
         self.dim = dim
         self.window_size = window_size
@@ -242,7 +249,7 @@ class WindowAttention(nn.Module):
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        q, k, v = qkv.unbind(0)
 
         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
@@ -266,74 +273,86 @@ class SpatialBlock(nn.Module):
         norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
     """
 
-    def __init__(self, dim, num_heads, window_size=7,
-                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
-                 ffn=True, cpe_act=False):
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            window_size=7,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            ffn=True,
+            cpe_act=False,
+    ):
         super().__init__()
         self.dim = dim
         self.ffn = ffn
         self.num_heads = num_heads
-        self.window_size = window_size
+        self.window_size = to_2tuple(window_size)
         self.mlp_ratio = mlp_ratio
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
-            window_size=to_2tuple(self.window_size),
+            self.window_size,
             num_heads=num_heads,
-            qkv_bias=qkv_bias)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+            qkv_bias=qkv_bias,
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
             mlp_hidden_dim = int(dim * mlp_ratio)
             self.mlp = Mlp(
                 in_features=dim,
                 hidden_features=mlp_hidden_dim,
-                act_layer=act_layer)
+                act_layer=act_layer,
+            )
+            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        else:
+            self.norm2 = None
+            self.mlp = None
+            self.drop_path1 = None
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
 
         shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
         x = self.norm1(shortcut)
         x = x.view(B, H, W, C)
 
         pad_l = pad_t = 0
-        pad_r = (self.window_size - W % self.window_size) % self.window_size
-        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+        pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
         x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
         _, Hp, Wp, _ = x.shape
 
         x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
 
         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows)
 
         # merge windows
-        attn_windows = attn_windows.view(-1,
-                                         self.window_size,
-                                         self.window_size,
-                                         C)
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
 
         # if pad_r > 0 or pad_b > 0:
         x = x[:, :H, :W, :].contiguous()
 
         x = x.view(B, H * W, C)
-        x = shortcut + self.drop_path(x)
+        x = shortcut + self.drop_path1(x)
 
-        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
-        if self.ffn:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
+        if self.mlp is not None:
+            x = x.flatten(2).transpose(1, 2)
+            x = x + self.drop_path2(self.mlp(self.norm2(x)))
         x = x.transpose(1, 2).view(B, C, H, W)
         return x
@@ -345,15 +364,15 @@ class DaViTStage(nn.Module):
             in_chs,
             out_chs,
             depth=1,
-            patch_size = 4,
-            overlapped_patch = False,
-            attention_types = ('spatial', 'channel'),
+            downsample=True,
+            attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
             mlp_ratio=4,
             qkv_bias=True,
             drop_path_rates=(0, 0),
-            norm_layer = nn.LayerNorm,
+            norm_layer=LayerNorm2d,
+            norm_layer_cl=nn.LayerNorm,
             ffn=True,
             cpe_act=False
     ):
@@ -361,13 +380,12 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False
 
-        # patch embedding layer at the beginning of each stage
-        self.patch_embed = PatchEmbed(
-            patch_size=patch_size,
-            in_chans=in_chs,
-            embed_dim=out_chs,
-            overlapped=overlapped_patch
-        )
+        # downsample embedding layer at the beginning of each stage
+        if downsample:
+            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+        else:
+            self.downsample = nn.Identity()
 
         '''
          repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
@@ -377,36 +395,32 @@ class DaViTStage(nn.Module):
         '''
         stage_blocks = []
         for block_idx in range(depth):
             dual_attention_block = []
-
-            for attention_id, attention_type in enumerate(attention_types):
-                if attention_type == 'spatial':
+            for attn_idx, attn_type in enumerate(attn_types):
+                if attn_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
-                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
-                        norm_layer=norm_layer,
+                        drop_path=drop_path_rates[block_idx],
+                        norm_layer=norm_layer_cl,
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
                     ))
-                elif attention_type == 'channel':
+                elif attn_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
-                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
-                        norm_layer=norm_layer,
+                        drop_path=drop_path_rates[block_idx],
+                        norm_layer=norm_layer_cl,
                         ffn=ffn,
                         cpe_act=cpe_act
                     ))
 
             stage_blocks.append(nn.Sequential(*dual_attention_block))
         self.blocks = nn.Sequential(*stage_blocks)
 
     @torch.jit.ignore
@@ -414,7 +428,7 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = enable
 
     def forward(self, x: Tensor):
-        x = self.patch_embed(x)
+        x = self.downsample(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
@@ -431,7 +445,6 @@ class DaViT(nn.Module):
         in_chans (int): Number of input image channels. Default: 3
         num_classes (int): Number of classes for classification head. Default: 1000
         depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
-        patch_size (int | tuple(int)): Patch size. Default: 4
         embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
         num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
         window_size (int): Window size. Default: 7
@@ -445,69 +458,61 @@ class DaViT(nn.Module):
             self,
             in_chans=3,
             depths=(1, 1, 3, 1),
-            patch_size=4,
             embed_dims=(96, 192, 384, 768),
             num_heads=(3, 6, 12, 24),
             window_size=7,
-            mlp_ratio=4.,
+            mlp_ratio=4,
             qkv_bias=True,
-            drop_path_rate=0.1,
-            norm_layer=nn.LayerNorm,
-            attention_types=('spatial', 'channel'),
+            norm_layer='layernorm2d',
+            norm_layer_cl='layernorm',
+            norm_eps=1e-5,
+            attn_types=('spatial', 'channel'),
             ffn=True,
-            overlapped_patch=False,
             cpe_act=False,
             drop_rate=0.,
             attn_drop_rate=0.,
+            drop_path_rate=0.,
             num_classes=1000,
             global_pool='avg',
             head_norm_first=False,
     ):
         super().__init__()
-
-        architecture = [[index] * item for index, item in enumerate(depths)]
-        self.architecture = architecture
-        self.embed_dims = embed_dims
-        self.num_heads = num_heads
-        self.num_stages = len(self.embed_dims)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
-        assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
-
+        num_stages = len(embed_dims)
+        assert num_stages == len(num_heads) == len(depths)
+        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
+        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
         self.num_classes = num_classes
         self.num_features = embed_dims[-1]
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
         self.feature_info = []
 
-        self.patch_embed = None
-        stages = []
-
-        for stage_id in range(self.num_stages):
-            stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
-
+        self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
+        in_chs = embed_dims[0]
+
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+        stages = []
+        for stage_idx in range(num_stages):
+            out_chs = embed_dims[stage_idx]
             stage = DaViTStage(
-                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
-                embed_dims[stage_id],
-                depth = depths[stage_id],
-                patch_size = patch_size if stage_id == 0 else 2,
-                overlapped_patch = overlapped_patch,
-                attention_types = attention_types,
-                num_heads = num_heads[stage_id],
+                in_chs,
+                out_chs,
+                depth=depths[stage_idx],
+                downsample=stage_idx > 0,
+                attn_types=attn_types,
+                num_heads=num_heads[stage_idx],
                 window_size=window_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
-                drop_path_rates = stage_drop_rates,
-                norm_layer = nn.LayerNorm,
+                drop_path_rates=dpr[stage_idx],
+                norm_layer=norm_layer,
+                norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
-                cpe_act = cpe_act
+                cpe_act=cpe_act,
             )
-
-            if stage_id == 0:
-                self.patch_embed = stage.patch_embed
-                stage.patch_embed = nn.Identity()
-
+            in_chs = out_chs
             stages.append(stage)
-            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
+            self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
 
         self.stages = nn.Sequential(*stages)
@@ -529,9 +534,6 @@ class DaViT(nn.Module):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
@@ -550,17 +552,17 @@ class DaViT(nn.Module):
             self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
-        x = self.patch_embed(x)
+        x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.stages, x)
         else:
             x = self.stages(x)
-        x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.norm_pre(x)
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
         x = self.head.global_pool(x)
-        x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.head.norm(x)
         x = self.head.flatten(x)
         x = self.head.drop(x)
         return x if pre_logits else self.head.fc(x)
@@ -573,7 +575,7 @@ class DaViT(nn.Module):
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
-    if 'head' in state_dict:
+    if 'head.fc.weight' in state_dict:
         return state_dict  # non-MSFT checkpoint
 
     if 'state_dict' in state_dict:
@@ -582,10 +584,10 @@ def checkpoint_filter_fn(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
-        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
+        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
         k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
-        k = k.replace('stages.0.patch_embed', 'patch_embed')
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
         k = k.replace('head.', 'head.fc.')
         k = k.replace('norms.', 'head.norm.')
         k = k.replace('cpe.0', 'cpe1')
@@ -595,7 +597,6 @@ def checkpoint_filter_fn(state_dict, model):
 
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
@@ -610,67 +611,69 @@ def _create_davit(variant, pretrained=False, **kwargs):
     return model
 
 
 def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.850, 'interpolation': 'bicubic',
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
         **kwargs
     }
 
 
 # TODO contact authors to get larger pretrained models
 default_cfgs = generate_default_cfgs({
     # official microsoft weights from https://github.com/dingmyu/davit
     'davit_tiny.msft_in1k': _cfg(
-        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
+        hf_hub_id='timm/'),
     'davit_small.msft_in1k': _cfg(
-        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
+        hf_hub_id='timm/'),
     'davit_base.msft_in1k': _cfg(
-        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
+        hf_hub_id='timm/'),
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
 })
 
 
 @register_model
 def davit_tiny(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
-                        num_heads=(3, 6, 12, 24), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
     return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
 def davit_small(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
-                        num_heads=(3, 6, 12, 24), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
     return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
 def davit_base(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
-                        num_heads=(4, 8, 16, 32), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
 def davit_large(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
-                        num_heads=(6, 12, 24, 48), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
 def davit_huge(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
-                        num_heads=(8, 16, 32, 64), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
     return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
 def davit_giant(pretrained=False, **kwargs):
-    model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
-                        num_heads=(12, 24, 48, 96), **kwargs)
+    model_kwargs = dict(
+        depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
     return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
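Since the stem is now a standalone module and each DaViTStage registers its output in feature_info, the backbone path through _create_davit should also work with features_only; a sketch, assuming the standard timm feature-extraction API:

import torch
import timm

# one output per stage by default (see default_out_indices in _create_davit)
backbone = timm.create_model('davit_tiny', features_only=True, pretrained=False)
feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # NCHW maps, channels follow embed_dims=(96, 192, 384, 768)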
