Davit std (#5)

* Update davit.py

* Update test_models.py

* starting point

* Davit revised (#4)

* clean up
pull/1583/head
Fredo Guan 2 years ago committed by GitHub
parent edea013dd1
commit c43340ddd4

tests/test_models.py

@@ -40,7 +40,7 @@ if 'GITHUB_ACTIONS' in os.environ:
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
-        'swin*giant*']
+        'swin*giant*', 'davit*giant', 'davit*huge']
     NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
 else:
     EXCLUDE_FILTERS = []
@@ -271,7 +271,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point
     'vit_large_*', 'vit_huge_*', 'vit_gi*',
 ]

timm/models/davit.py

@@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # All rights reserved.
 # This source code is licensed under the MIT license
+# FIXME remove unused imports
 import itertools
-from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
 from collections import OrderedDict
 import torch
@@ -32,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 __all__ = ['DaViT']


 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -50,25 +53,21 @@ class ConvPosEnc(nn.Module):
             self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()

-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        B, N, C = x.shape
-        H, W = size
-        feat = x.transpose(1, 2).view(B, C, H, W)
-        feat = self.proj(feat)
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        #feat = x.transpose(1, 2).view(B, C, H, W)
+        feat = self.proj(x)
         if self.normtype == 'batch':
             feat = self.norm(feat).flatten(2).transpose(1, 2)
         elif self.normtype == 'layer':
             feat = self.norm(feat.flatten(2).transpose(1, 2))
         else:
             feat = feat.flatten(2).transpose(1, 2)
-        x = x + self.activation(feat)
+        x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
         return x

-# reason: dim in control sequence
-# FIXME reimplement to allow tracing
-@register_notrace_module
 class PatchEmbed(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
@@ -76,13 +75,15 @@ class PatchEmbed(nn.Module):
     def __init__(
             self,
-            patch_size=16,
+            patch_size=4,
             in_chans=3,
             embed_dim=96,
             overlapped=False):
         super().__init__()
         patch_size = to_2tuple(patch_size)
         self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim

         if patch_size[0] == 4:
             self.proj = nn.Conv2d(
@@ -104,30 +105,19 @@ class PatchEmbed(nn.Module):
             self.norm = nn.LayerNorm(in_chans)

-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        H, W = size
-        dim = x.dim()
-        if dim == 3:
-            B, HW, C = x.shape
-            x = self.norm(x)
-            x = x.reshape(B,
-                          H,
-                          W,
-                          C).permute(0, 3, 1, 2).contiguous()
+    def forward(self, x : Tensor):
         B, C, H, W = x.shape
-        if W % self.patch_size[1] != 0:
-            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
-        if H % self.patch_size[0] != 0:
-            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+        if self.norm.normalized_shape[0] == self.in_chans:
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
+        x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))
         x = self.proj(x)
-        newsize = (x.size(2), x.size(3))
-        x = x.flatten(2).transpose(1, 2)
-        if dim == 4:
-            x = self.norm(x)
-        return x, newsize
+        if self.norm.normalized_shape[0] == self.embed_dim:
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return x


 class ChannelAttention(nn.Module):
@@ -162,12 +152,12 @@ class ChannelBlock(nn.Module):
                  ffn=True, cpe_act=False):
         super().__init__()
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
         self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
@@ -178,17 +168,23 @@ class ChannelBlock(nn.Module):
                 act_layer=act_layer)

-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        x = self.cpe[0](x, size)
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        x = self.cpe1(x).flatten(2).transpose(1, 2)
         cur = self.norm1(x)
         cur = self.attn(cur)
         x = x + self.drop_path(cur)
-        x = self.cpe[1](x, size)
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
-        return x, size
+        x = x.transpose(1, 2).view(B, C, H, W)
+        return x


 def window_partition(x : Tensor, window_size: int):
     """
@@ -283,9 +279,8 @@ class SpatialBlock(nn.Module):
         self.num_heads = num_heads
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
@@ -294,6 +289,7 @@ class SpatialBlock(nn.Module):
             qkv_bias=qkv_bias)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
@@ -304,12 +300,11 @@ class SpatialBlock(nn.Module):
                 act_layer=act_layer)

-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        H, W = size
-        B, L, C = x.shape
-        shortcut = self.cpe[0](x, size)
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
         x = self.norm1(shortcut)
         x = x.view(B, H, W, C)
@@ -338,11 +333,92 @@ class SpatialBlock(nn.Module):
         x = x.view(B, H * W, C)
         x = shortcut + self.drop_path(x)

-        x = self.cpe[1](x, size)
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
-        return x, size
+        x = x.transpose(1, 2).view(B, C, H, W)
+        return x
+
+
+class DaViTStage(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth = 1,
+            patch_size = 4,
+            overlapped_patch = False,
+            attention_types = ('spatial', 'channel'),
+            num_heads = 3,
+            window_size = 7,
+            mlp_ratio = 4,
+            qkv_bias = True,
+            drop_path_rates = (0, 0),
+            norm_layer = nn.LayerNorm,
+            ffn = True,
+            cpe_act = False
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+
+        # patch embedding layer at the beginning of each stage
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chs,
+            embed_dim=out_chs,
+            overlapped=overlapped_patch
+        )
+
+        '''
+        repeating alternating attention blocks in each stage
+        default: (spatial -> channel) x depth
+        potential opportunity to integrate with a more general version of ByobNet/ByoaNet
+        since the logic is similar
+        '''
+        stage_blocks = []
+        for block_idx in range(depth):
+            dual_attention_block = []
+            for attention_id, attention_type in enumerate(attention_types):
+                if attention_type == 'spatial':
+                    dual_attention_block.append(SpatialBlock(
+                        dim=out_chs,
+                        num_heads=num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
+                        norm_layer=nn.LayerNorm,
+                        ffn=ffn,
+                        cpe_act=cpe_act,
+                        window_size=window_size,
+                    ))
+                elif attention_type == 'channel':
+                    dual_attention_block.append(ChannelBlock(
+                        dim=out_chs,
+                        num_heads=num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
+                        norm_layer=nn.LayerNorm,
+                        ffn=ffn,
+                        cpe_act=cpe_act
+                    ))
+            stage_blocks.append(nn.Sequential(*dual_attention_block))
+        self.blocks = nn.Sequential(*stage_blocks)
+
+    def forward(self, x : Tensor):
+        x = self.patch_embed(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        return x
+
+
 class DaViT(nn.Module):
@@ -392,7 +468,7 @@ class DaViT(nn.Module):
         self.embed_dims = embed_dims
         self.num_heads = num_heads
         self.num_stages = len(self.embed_dims)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
         assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
         self.num_classes = num_classes
@@ -401,52 +477,38 @@ class DaViT(nn.Module):
         self.grad_checkpointing = False
         self.feature_info = []

-        self.patch_embeds = nn.ModuleList([
-            PatchEmbed(patch_size=patch_size if i == 0 else 2,
-                       in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
-                       embed_dim=self.embed_dims[i],
-                       overlapped=overlapped_patch)
-            for i in range(self.num_stages)])
-
-        self.stages = nn.ModuleList()
-        for stage_id, stage_param in enumerate(self.architecture):
-            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
-
-            stage = nn.ModuleList([
-                nn.ModuleList([
-                    ChannelBlock(
-                        dim=self.embed_dims[item],
-                        num_heads=self.num_heads[item],
-                        mlp_ratio=mlp_ratio,
-                        qkv_bias=qkv_bias,
-                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
-                        norm_layer=nn.LayerNorm,
-                        ffn=ffn,
-                        cpe_act=cpe_act
-                    ) if attention_type == 'channel' else
-                    SpatialBlock(
-                        dim=self.embed_dims[item],
-                        num_heads=self.num_heads[item],
-                        mlp_ratio=mlp_ratio,
-                        qkv_bias=qkv_bias,
-                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
-                        norm_layer=nn.LayerNorm,
-                        ffn=ffn,
-                        cpe_act=cpe_act,
-                        window_size=window_size,
-                    ) if attention_type == 'spatial' else None
-                    for attention_id, attention_type in enumerate(attention_types)]
-                ) for layer_id, item in enumerate(stage_param)
-            ])
-
-            self.stages.add_module(f'stage_{stage_id}', stage)
-            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
+        stages = []
+        for stage_id in range(self.num_stages):
+            stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
+            stage = DaViTStage(
+                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
+                embed_dims[stage_id],
+                depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
+                overlapped_patch = overlapped_patch,
+                attention_types = attention_types,
+                num_heads = num_heads[stage_id],
+                window_size = window_size,
+                mlp_ratio = mlp_ratio,
+                qkv_bias = qkv_bias,
+                drop_path_rates = stage_drop_rates,
+                norm_layer = nn.LayerNorm,
+                ffn = ffn,
+                cpe_act = cpe_act
+            )
+            stages.append(stage)
+            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
+        self.stages = nn.Sequential(*stages)

         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         self.apply(self._init_weights)

     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
@@ -470,45 +532,12 @@ class DaViT(nn.Module):
             global_pool = self.head.global_pool.pool_type
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    def forward_network(self, x):
-        size: Tuple[int, int] = (x.size(2), x.size(3))
-        features = [x]
-        sizes = [size]
-        for patch_layer, stage in zip(self.patch_embeds, self.stages):
-            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for _, block in enumerate(stage):
-                for _, layer in enumerate(block):
-                    if self.grad_checkpointing and not torch.jit.is_scripting():
-                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
-                    else:
-                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])
-            # don't append outputs of last stage, since they are already there
-            if(len(features) < self.num_stages):
-                features.append(features[-1])
-                sizes.append(sizes[-1])
-        # non-normalized pyramid features + corresponding sizes
-        return features, sizes
-
-    def forward_pyramid_features(self, x) -> List[Tensor]:
-        x, sizes = self.forward_network(x)
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-        return outs
-
     def forward_features(self, x):
-        x, sizes = self.forward_network(x)
+        x = self.stages(x)
         # take final feature and norm
-        x = self.norms(x[-1])
-        H, W = sizes[-1]
-        x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
+        x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        #H, W = sizes[-1]
+        #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
         return x

     def forward_head(self, x, pre_logits: bool = False):
@@ -522,17 +551,6 @@ class DaViT(nn.Module):
     def forward(self, x):
         return self.forward_classifier(x)

-class DaViTFeatures(DaViT):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
-
-    def forward(self, x) -> List[Tensor]:
-        return self.forward_pyramid_features(x)
-
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.norm.weight' in state_dict:
@@ -541,35 +559,33 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']

+    import re
     out_dict = {}
     for k, v in state_dict.items():
-        k = k.replace('main_blocks.', 'stages.stage_')
+        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
+        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
         k = k.replace('head.', 'head.fc.')
+        k = k.replace('cpe.0', 'cpe1')
+        k = k.replace('cpe.1', 'cpe2')
         out_dict[k] = v
     return out_dict


 def _create_davit(variant, pretrained=False, **kwargs):
-    model_cls = DaViT
-    features_only = False
-    kwargs_filter = None
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-    if kwargs.pop('features_only', False):
-        model_cls = DaViTFeatures
-        kwargs_filter = ('num_classes', 'global_pool')
-        features_only = True
     model = build_model_with_cfg(
-        model_cls,
+        DaViT,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs)

-    if features_only:
-        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
-        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
@@ -580,7 +596,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
+        'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
@@ -594,6 +610,9 @@ default_cfgs = generate_default_cfgs({
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
     'davit_base.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
+    'davit_large': _cfg(),
+    'davit_huge': _cfg(),
+    'davit_giant': _cfg(),
 })
@@ -616,7 +635,7 @@ def davit_base(pretrained=False, **kwargs):
         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)

-''' models without weights
 # TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
@@ -635,4 +654,3 @@ def davit_giant(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
                         num_heads=(12, 24, 48, 96), **kwargs)
     return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
-'''
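
For reference, a minimal usage sketch of the reworked model (not part of the commit; it assumes a timm build that includes this version of davit.py). With the stages collected into an nn.Sequential of DaViTStage modules passing NCHW tensors end to end, classification runs through forward_features/forward_head as before, and pyramid features are meant to come from the generic features_only machinery (feature_cfg with flatten_sequential and out_indices) instead of the removed DaViTFeatures class.

import torch
import timm  # assumes the branch containing this davit.py is installed

x = torch.randn(1, 3, 224, 224)

# Classification path: forward_features now returns an NCHW map
# (e.g. (1, 1024, 7, 7) for davit_base) and forward() returns (1, 1000) logits.
model = timm.create_model('davit_base', pretrained=False)
feats = model.forward_features(x)
logits = model(x)

# Feature pyramid via the standard timm features_only wrapper, the intended
# replacement for DaViTFeatures (a sketch; behavior depends on the surrounding timm version).
feat_model = timm.create_model('davit_base', features_only=True, pretrained=False)
for f in feat_model(x):
    print(f.shape)  # one NCHW feature map per stage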