|
|
|
@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|
|
|
|
# All rights reserved.
|
|
|
|
|
# This source code is licensed under the MIT license
|
|
|
|
|
|
|
|
|
|
# FIXME remove unused imports
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
|
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
@ -24,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from .features import FeatureInfo
|
|
|
|
|
from .fx_features import register_notrace_function, register_notrace_module
|
|
|
|
|
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
|
|
|
|
from .helpers import build_model_with_cfg, checkpoint_seq pretrained_cfg_for_features
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
from .registry import register_model
|
|
|
|
@ -32,6 +34,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
|
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
for module in self._modules.values():
|
|
|
|
|
x, size = module(x, size)
|
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
|
|
|
|
|
@ -343,6 +352,76 @@ class SpatialBlock(nn.Module):
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
class DaViTStage(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
depth = 1,
|
|
|
|
|
patch_size = 4,
|
|
|
|
|
overlapped_patch = False,
|
|
|
|
|
attention_types = ('spatial', 'channel'),
|
|
|
|
|
num_heads = 3,
|
|
|
|
|
window_size = 7,
|
|
|
|
|
mlp_ratio = 4,
|
|
|
|
|
qkv_bias = True,
|
|
|
|
|
drop_path_rates = (0, 0),
|
|
|
|
|
norm_layer = nn.LayerNorm,
|
|
|
|
|
ffn = True,
|
|
|
|
|
cpe_act = False
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
|
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
|
|
|
patch_size=patch_size,
|
|
|
|
|
in_chans=in_chs,
|
|
|
|
|
embed_dim=out_chs,
|
|
|
|
|
overlapped=overlapped_patch
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
stage_blocks = []
|
|
|
|
|
|
|
|
|
|
for block_idx in range(depth):
|
|
|
|
|
|
|
|
|
|
dual_attention_block = []
|
|
|
|
|
|
|
|
|
|
for attention_id, attention_type in enumerate(attention_types):
|
|
|
|
|
if attention_type == 'channel':
|
|
|
|
|
dual_attention_block.append(ChannelBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act
|
|
|
|
|
))
|
|
|
|
|
elif attention_type == 'spatial':
|
|
|
|
|
dual_attention_block.append(SpatialBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
stage_blocks.append(SequentialWithSize(*dual_attention_block))
|
|
|
|
|
|
|
|
|
|
self.blocks = SequentialWithSize(*stage_blocks)
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
x, size = self.patch_embed(x, size)
|
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
|
x, size = checkpoint_seq(self.blocks, x, size)
|
|
|
|
|
else:
|
|
|
|
|
x, size = self.blocks(x, size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DaViT(nn.Module):
|
|
|
|
@ -392,7 +471,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.embed_dims = embed_dims
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.num_stages = len(self.embed_dims)
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
|
|
|
|
|
assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
|
|
|
|
|
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
@ -401,47 +480,34 @@ class DaViT(nn.Module):
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
self.feature_info = []
|
|
|
|
|
|
|
|
|
|
self.patch_embeds = nn.ModuleList([
|
|
|
|
|
PatchEmbed(patch_size=patch_size if i == 0 else 2,
|
|
|
|
|
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
|
|
|
|
|
embed_dim=self.embed_dims[i],
|
|
|
|
|
overlapped=overlapped_patch)
|
|
|
|
|
for i in range(self.num_stages)])
|
|
|
|
|
|
|
|
|
|
self.stages = nn.ModuleList()
|
|
|
|
|
for stage_id, stage_param in enumerate(self.architecture):
|
|
|
|
|
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
|
|
|
|
|
|
|
|
|
|
stage = nn.ModuleList([
|
|
|
|
|
nn.ModuleList([
|
|
|
|
|
ChannelBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act
|
|
|
|
|
) if attention_type == 'channel' else
|
|
|
|
|
SpatialBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
) if attention_type == 'spatial' else None
|
|
|
|
|
for attention_id, attention_type in enumerate(attention_types)]
|
|
|
|
|
) for layer_id, item in enumerate(stage_param)
|
|
|
|
|
])
|
|
|
|
|
stages = []
|
|
|
|
|
|
|
|
|
|
for stage_id in range(self.num_stages):
|
|
|
|
|
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
|
|
|
|
|
|
|
|
|
|
stage = DaViTStage(
|
|
|
|
|
in_chans if stage_id == 0 else embed_dims[i - 1],
|
|
|
|
|
embed_dims[stage_id],
|
|
|
|
|
depth = 1,
|
|
|
|
|
patch_size = patch_size,
|
|
|
|
|
overlapped_patch = overlapped_patch,
|
|
|
|
|
attention_types = attention_types,
|
|
|
|
|
num_heads = num_heads[stage_id],
|
|
|
|
|
window_size = window_size,
|
|
|
|
|
mlp_ratio = mlp_ratio,
|
|
|
|
|
qkv_bias = qkv_bias,
|
|
|
|
|
drop_path_rates = stage_drop_rates,
|
|
|
|
|
norm_layer = nn.LayerNorm,
|
|
|
|
|
ffn = ffn,
|
|
|
|
|
cpe_act = cpe_act
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.stages.add_module(f'stage_{stage_id}', stage)
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
|
|
|
|
|
|
|
|
|
|
stages.append(stage)
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = SequentialWithSize(*stages)
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
@ -471,66 +537,21 @@ class DaViT(nn.Module):
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_network(self, x):
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
features = [x]
|
|
|
|
|
sizes = [size]
|
|
|
|
|
|
|
|
|
|
for patch_layer, stage in zip(self.patch_embeds, self.stages):
|
|
|
|
|
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
|
|
|
|
|
for _, block in enumerate(stage):
|
|
|
|
|
for _, layer in enumerate(block):
|
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
|
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
|
|
|
|
|
else:
|
|
|
|
|
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
|
|
|
|
|
|
|
|
|
|
# don't append outputs of last stage, since they are already there
|
|
|
|
|
if(len(features) < self.num_stages):
|
|
|
|
|
features.append(features[-1])
|
|
|
|
|
sizes.append(sizes[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
|
return features, sizes
|
|
|
|
|
|
|
|
|
|
def forward_pyramid_features(self, x) -> List[Tensor]:
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
|
outs = []
|
|
|
|
|
for i, out in enumerate(x):
|
|
|
|
|
H, W = sizes[i]
|
|
|
|
|
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
|
|
|
|
|
|
|
|
|
|
return outs
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
|
# take final feature and norm
|
|
|
|
|
x = self.norms(x[-1])
|
|
|
|
|
H, W = sizes[-1]
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
x, size = self.stages(x, size)
|
|
|
|
|
x = self.norms(x)
|
|
|
|
|
H, W = size
|
|
|
|
|
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
|
return self.head(x, pre_logits=pre_logits)
|
|
|
|
|
|
|
|
|
|
def forward_classifier(self, x):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
x = self.forward_head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return self.forward_classifier(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
return self.forward_pyramid_features(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
@ -551,25 +572,15 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model_cls = DaViT
|
|
|
|
|
features_only = False
|
|
|
|
|
kwargs_filter = None
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
if kwargs.pop('features_only', False):
|
|
|
|
|
model_cls = DaViTFeatures
|
|
|
|
|
kwargs_filter = ('num_classes', 'global_pool')
|
|
|
|
|
features_only = True
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
model_cls,
|
|
|
|
|
DaViT,
|
|
|
|
|
variant,
|
|
|
|
|
pretrained,
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
|
|
|
|
**kwargs)
|
|
|
|
|
if features_only:
|
|
|
|
|
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
|
|
|
|
|
model.default_cfg = model.pretrained_cfg # backwards compat
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|