Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8cc2805e7b
commit 187c051ac0

@@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
# FIXME remove unused imports
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
from collections import OrderedDict
import torch
@@ -24,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
from .features import FeatureInfo
from .fx_features import register_notrace_function, register_notrace_module
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
from .pretrained import generate_default_cfgs
from .registry import register_model
@@ -32,6 +34,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']
class SequentialWithSize(nn.Sequential):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        for module in self._modules.values():
            x, size = module(x, size)
        return x, size
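
# Illustrative sketch (not part of the diff): SequentialWithSize exists because
# plain nn.Sequential threads only a single tensor, while DaViT blocks use an
# (x, size) -> (x, size) contract so spatial dims survive token flattening.
# Assuming two such blocks:
#
#   seq = SequentialWithSize(block_a, block_b)
#   x, size = seq(x, (224 // 4, 224 // 4))  # size shrinks if a block downsamples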
class ConvPosEnc(nn.Module):
    def __init__(self, dim: int, k: int = 3, act: bool = False, normtype: str = 'none'):
@@ -343,6 +352,76 @@ class SpatialBlock(nn.Module):
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size
class DaViTStage(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            depth=1,
            patch_size=4,
            overlapped_patch=False,
            attention_types=('spatial', 'channel'),
            num_heads=3,
            window_size=7,
            mlp_ratio=4,
            qkv_bias=True,
            drop_path_rates=(0, 0),
            norm_layer=nn.LayerNorm,
            ffn=True,
            cpe_act=False
    ):
        super().__init__()

        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chs,
            embed_dim=out_chs,
            overlapped=overlapped_patch
        )
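        # PatchEmbed both merges patches and projects in_chs -> out_chs
        # (roughly a strided convolution), so it is the step that changes
        # the spatial `size` threaded through this stage.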
        stage_blocks = []
        for block_idx in range(depth):
            dual_attention_block = []
            for attention_id, attention_type in enumerate(attention_types):
                if attention_type == 'channel':
                    dual_attention_block.append(ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act
                    ))
                elif attention_type == 'spatial':
                    dual_attention_block.append(SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    ))
            stage_blocks.append(SequentialWithSize(*dual_attention_block))

        self.blocks = SequentialWithSize(*stage_blocks)
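        # Resulting layout: blocks = Sequential(depth x Sequential(spatial, channel))
        # with the default attention_types; one "block" is a dual-attention pair.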
    def forward(self, x: Tensor, size: Tuple[int, int]):
        x, size = self.patch_embed(x, size)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x, size = checkpoint_seq(self.blocks, x, size)
        else:
            x, size = self.blocks(x, size)
        return x, size
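    # Hedged note: upstream timm's checkpoint_seq is written for sequences whose
    # modules take and return a single tensor; threading (x, size) through it is
    # an assumption of this refactor. If that contract does not hold, a fallback
    # is torch.utils.checkpoint.checkpoint(block, x, size) per sub-block, as the
    # removed forward_network below did.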
class DaViT(nn.Module):
@@ -392,7 +471,7 @@ class DaViT(nn.Module):
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_stages = len(self.embed_dims)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
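        # Linear stochastic-depth schedule over all attention sub-blocks, e.g.
        # drop_path_rate=0.1 with 2 attention types and depths summing to 6
        # gives torch.linspace(0, 0.1, 12): 0.0000, 0.0091, ..., 0.1000.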
        assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
        self.num_classes = num_classes
@@ -401,47 +480,34 @@
        self.grad_checkpointing = False
        self.feature_info = []
        self.patch_embeds = nn.ModuleList([
            PatchEmbed(patch_size=patch_size if i == 0 else 2,
                       in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                       embed_dim=self.embed_dims[i],
                       overlapped=overlapped_patch)
            for i in range(self.num_stages)])

        self.stages = nn.ModuleList()
        for stage_id, stage_param in enumerate(self.architecture):
            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))

            stage = nn.ModuleList([
                nn.ModuleList([
                    ChannelBlock(
                        dim=self.embed_dims[item],
                        num_heads=self.num_heads[item],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act
                    ) if attention_type == 'channel' else
                    SpatialBlock(
                        dim=self.embed_dims[item],
                        num_heads=self.num_heads[item],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    ) if attention_type == 'spatial' else None
                    for attention_id, attention_type in enumerate(attention_types)]
                ) for layer_id, item in enumerate(stage_param)
            ])
        stages = []
        for stage_id in range(self.num_stages):
            stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
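            # Each stage takes its contiguous slice of dpr; e.g. with
            # depths=(1, 1, 3, 1) and 2 attention types, stage 2 gets
            # dpr[4:10] (3 blocks x 2 sub-blocks per block).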
            stage = DaViTStage(
                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                embed_dims[stage_id],
                depth=depths[stage_id],
                patch_size=patch_size if stage_id == 0 else 2,
                overlapped_patch=overlapped_patch,
                attention_types=attention_types,
                num_heads=num_heads[stage_id],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rates=stage_drop_rates,
                norm_layer=nn.LayerNorm,
                ffn=ffn,
                cpe_act=cpe_act
            )
            self.stages.add_module(f'stage_{stage_id}', stage)
            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
            stages.append(stage)
            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]

        self.stages = SequentialWithSize(*stages)
        self.norms = norm_layer(self.num_features)
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        self.apply(self._init_weights)
@@ -471,66 +537,21 @@
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
    def forward_network(self, x):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        features = [x]
        sizes = [size]

        for patch_layer, stage in zip(self.patch_embeds, self.stages):
            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
            for _, block in enumerate(stage):
                for _, layer in enumerate(block):
                    if self.grad_checkpointing and not torch.jit.is_scripting():
                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
                    else:
                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])
            # don't append outputs of last stage, since they are already there
            if len(features) < self.num_stages:
                features.append(features[-1])
                sizes.append(sizes[-1])

        # non-normalized pyramid features + corresponding sizes
        return features, sizes

    def forward_pyramid_features(self, x) -> List[Tensor]:
        x, sizes = self.forward_network(x)
        outs = []
        for i, out in enumerate(x):
            H, W = sizes[i]
            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
        return outs
    def forward_features(self, x):
        x, sizes = self.forward_network(x)
        # take final feature and norm
        x = self.norms(x[-1])
        H, W = sizes[-1]
        size: Tuple[int, int] = (x.size(2), x.size(3))
        x, size = self.stages(x, size)
        x = self.norms(x)
        H, W = size
        x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
        return x
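    # Shape note: stages emit tokens as (B, H*W, C); the view/permute above
    # restores (B, C, H, W) for ClassifierHead and any downstream convs.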
    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)
    def forward_classifier(self, x):
    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
    def forward(self, x):
        return self.forward_classifier(x)
class DaViTFeatures(DaViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))

    def forward(self, x) -> List[Tensor]:
        return self.forward_pyramid_features(x)
def checkpoint_filter_fn(state_dict, model):
@@ -551,25 +572,15 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
    model_cls = DaViT
    features_only = False
    kwargs_filter = None
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)
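    # default_out_indices spans every stage, e.g. depths=(1, 1, 3, 1) -> (0, 1, 2, 3);
    # callers may pass out_indices to keep a subset of pyramid levels.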
    if kwargs.pop('features_only', False):
        model_cls = DaViTFeatures
        kwargs_filter = ('num_classes', 'global_pool')
        features_only = True

    model = build_model_with_cfg(
        model_cls,
        DaViT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)

    if features_only:
        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
        model.default_cfg = model.pretrained_cfg  # backwards compat
    return model
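
# Minimal usage sketch (assumes a registered variant such as 'davit_tiny'):
#
#   model = _create_davit('davit_tiny', pretrained=False)    # classification model
#   feats = _create_davit('davit_tiny', features_only=True)  # DaViTFeatures wrapper
#   # feats(x) returns List[Tensor], one (B, C_i, H_i, W_i) map per out_index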
