Finalize DaViT, some formatting and modelling simplifications (separate PatchEmbed to Stem + Downsample, weights on HF hub.

pull/1654/head
Ross Wightman 2 years ago
parent fb717056da
commit 9a53c3f727

@ -29,7 +29,6 @@ NON_STD_FILTERS = [
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*'
]
#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', '
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -11,9 +11,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
@ -21,9 +22,8 @@ import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp # ClassifierHead
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from ._builder import build_model_with_cfg
from ._features import FeatureInfo
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
@ -33,89 +33,83 @@ __all__ = ['DaViT']
class ConvPosEnc(nn.Module):
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim,
dim,
to_2tuple(k),
to_2tuple(1),
to_2tuple(k // 2),
groups=dim)
self.normtype = normtype
self.norm = nn.Identity()
if self.normtype == 'batch':
self.norm = nn.BatchNorm2d(dim)
elif self.normtype == 'layer':
self.norm = nn.LayerNorm(dim)
self.activation = nn.GELU() if act else nn.Identity()
def forward(self, x : Tensor):
B, C, H, W = x.shape
#feat = x.transpose(1, 2).view(B, C, H, W)
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
self.act = nn.GELU() if act else nn.Identity()
def forward(self, x: Tensor):
feat = self.proj(x)
if self.normtype == 'batch':
feat = self.norm(feat).flatten(2).transpose(1, 2)
elif self.normtype == 'layer':
feat = self.norm(feat.flatten(2).transpose(1, 2))
else:
feat = feat.flatten(2).transpose(1, 2)
x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
x = x + self.act(feat)
return x
class PatchEmbed(nn.Module):
class Stem(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
def __init__(
self,
patch_size=4,
in_chans=3,
embed_dim=96,
overlapped=False):
in_chs=3,
out_chs=96,
stride=4,
norm_layer=LayerNorm2d,
):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
if patch_size[0] == 4:
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=(7, 7),
stride=patch_size,
padding=(3, 3))
self.norm = nn.LayerNorm(embed_dim)
if patch_size[0] == 2:
kernel = 3 if overlapped else 2
pad = 1 if overlapped else 0
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=to_2tuple(kernel),
stride=patch_size,
padding=to_2tuple(pad))
self.norm = nn.LayerNorm(in_chans)
def forward(self, x : Tensor):
stride = to_2tuple(stride)
self.stride = stride
self.in_chs = in_chs
self.out_chs = out_chs
assert stride[0] == 4 # only setup for stride==4
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=7,
stride=stride,
padding=3,
)
self.norm = norm_layer(out_chs)
def forward(self, x: Tensor):
B, C, H, W = x.shape
if self.norm.normalized_shape[0] == self.in_chans:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
x = self.conv(x)
x = self.norm(x)
return x
x = self.proj(x)
if self.norm.normalized_shape[0] == self.embed_dim:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class Downsample(nn.Module):
def __init__(
self,
in_chs,
out_chs,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.norm = norm_layer(in_chs)
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=2,
stride=2,
padding=0,
)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
x = F.pad(x, (0, (2 - W % 2) % 2))
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
x = self.conv(x)
return x
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
@ -127,11 +121,11 @@ class ChannelAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x : Tensor):
def forward(self, x: Tensor):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k, v = qkv.unbind(0)
k = k * self.scale
attention = k.transpose(-1, -2) @ v
@ -140,50 +134,64 @@ class ChannelAttention(nn.Module):
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class ChannelBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
ffn=True, cpe_act=False):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer)
def forward(self, x : Tensor):
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path2 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.cpe1(x).flatten(2).transpose(1, 2)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path(cur)
x = x + self.drop_path1(cur)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
def window_partition(x : Tensor, window_size: int):
def window_partition(x: Tensor, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
@ -192,12 +200,13 @@ def window_partition(x : Tensor, window_size: int):
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
@ -207,9 +216,8 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@ -225,7 +233,6 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
super().__init__()
self.dim = dim
self.window_size = window_size
@ -238,11 +245,11 @@ class WindowAttention(nn.Module):
self.softmax = nn.Softmax(dim=-1)
def forward(self, x : Tensor):
def forward(self, x: Tensor):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
@ -266,108 +273,119 @@ class SpatialBlock(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7,
mlp_ratio=4., qkv_bias=True, drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
ffn=True, cpe_act=False):
def __init__(
self,
dim,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.dim = dim
self.ffn = ffn
self.num_heads = num_heads
self.window_size = window_size
self.window_size = to_2tuple(window_size)
self.mlp_ratio = mlp_ratio
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias)
qkv_bias=qkv_bias,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer)
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path1 = None
def forward(self, x : Tensor):
def forward(self, x: Tensor):
B, C, H, W = x.shape
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows)
# merge windows
attn_windows = attn_windows.view(-1,
self.window_size,
self.window_size,
C)
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
#if pad_r > 0 or pad_b > 0:
# if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = shortcut + self.drop_path1(x)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth = 1,
patch_size = 4,
overlapped_patch = False,
attention_types = ('spatial', 'channel'),
num_heads = 3,
window_size = 7,
mlp_ratio = 4,
qkv_bias = True,
drop_path_rates = (0, 0),
norm_layer = nn.LayerNorm,
ffn = True,
cpe_act = False
self,
in_chs,
out_chs,
depth=1,
downsample=True,
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
cpe_act=False
):
super().__init__()
self.grad_checkpointing = False
# patch embedding layer at the beginning of each stage
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chs,
embed_dim=out_chs,
overlapped=overlapped_patch
)
# downsample embedding layer at the beginning of each stage
if downsample:
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
@ -377,44 +395,40 @@ class DaViTStage(nn.Module):
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attention_id, attention_type in enumerate(attention_types):
if attention_type == 'spatial':
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=norm_layer,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
elif attention_type == 'channel':
elif attn_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=norm_layer,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x : Tensor):
x = self.patch_embed(x)
def forward(self, x: Tensor):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
@ -431,7 +445,6 @@ class DaViT(nn.Module):
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
patch_size (int | tuple(int)): Patch size. Default: 4
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int): Window size. Default: 7
@ -442,75 +455,67 @@ class DaViT(nn.Module):
"""
def __init__(
self,
in_chans=3,
depths=(1, 1, 3, 1),
patch_size=4,
embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
attention_types=('spatial', 'channel'),
ffn=True,
overlapped_patch=False,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
num_classes=1000,
global_pool='avg',
head_norm_first=False,
self,
in_chans=3,
depths=(1, 1, 3, 1),
embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4,
qkv_bias=True,
norm_layer='layernorm2d',
norm_layer_cl='layernorm',
norm_eps=1e-5,
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
global_pool='avg',
head_norm_first=False,
):
super().__init__()
architecture = [[index] * item for index, item in enumerate(depths)]
self.architecture = architecture
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_stages = len(self.embed_dims)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
num_stages = len(embed_dims)
assert num_stages == len(num_heads) == len(depths)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.drop_rate=drop_rate
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.patch_embed = None
stages = []
for stage_id in range(self.num_stages):
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
in_chs = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
stage = DaViTStage(
in_chans if stage_id == 0 else embed_dims[stage_id - 1],
embed_dims[stage_id],
depth = depths[stage_id],
patch_size = patch_size if stage_id == 0 else 2,
overlapped_patch = overlapped_patch,
attention_types = attention_types,
num_heads = num_heads[stage_id],
window_size = window_size,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop_path_rates = stage_drop_rates,
norm_layer = nn.LayerNorm,
ffn = ffn,
cpe_act = cpe_act
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
)
if stage_id == 0:
self.patch_embed = stage.patch_embed
stage.patch_embed = nn.Identity()
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
@ -521,28 +526,25 @@ class DaViT(nn.Module):
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -550,21 +552,21 @@ class DaViT(nn.Module):
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.head.global_pool(x)
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
@ -573,29 +575,28 @@ class DaViT(nn.Module):
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head' in state_dict:
if 'head.fc.weight' in state_dict:
return state_dict # non-MSFT checkpoint
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('stages.0.patch_embed', 'patch_embed')
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
k = k.replace('head.', 'head.fc.')
k = k.replace('norms.', 'head.norm.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
return out_dict
def _create_davit(variant, pretrained=False, **kwargs):
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
@ -608,69 +609,71 @@ def _create_davit(variant, pretrained=False, **kwargs):
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.850, 'interpolation': 'bicubic',
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
hf_hub_id='timm/'),
'davit_small.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
hf_hub_id='timm/'),
'davit_base.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
hf_hub_id='timm/'),
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
})
@register_model
def davit_tiny(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24), **kwargs)
model_kwargs = dict(
depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
@register_model
def davit_small(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24), **kwargs)
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
@register_model
def davit_base(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
num_heads=(4, 8, 16, 32), **kwargs)
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
@register_model
def davit_large(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
num_heads=(6, 12, 24, 48), **kwargs)
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
@register_model
def davit_huge(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
num_heads=(8, 16, 32, 64), **kwargs)
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
@register_model
def davit_giant(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
num_heads=(12, 24, 48, 96), **kwargs)
model_kwargs = dict(
depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)

Loading…
Cancel
Save