From 187c051ac03d9edf57ae36861d27488028b1c749 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 00:11:58 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 219 +++++++++++++++++++++++--------------------
 1 file changed, 115 insertions(+), 104 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index eda928e4..0357370f 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # All rights reserved.
 # This source code is licensed under the MIT license
 
+# FIXME remove unused imports
+
 import itertools
-from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
 from collections import OrderedDict
 
 import torch
@@ -24,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
 
 from .features import FeatureInfo
 from .fx_features import register_notrace_function, register_notrace_module
-from .helpers import build_model_with_cfg, pretrained_cfg_for_features
+from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features
 from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
 from .pretrained import generate_default_cfgs
 from .registry import register_model
@@ -32,6 +34,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['DaViT']
 
+class SequentialWithSize(nn.Sequential):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
+        for module in self._modules.values():
+            x, size = module(x, size)
+        return x, size
+
+
 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -343,6 +352,76 @@ class SpatialBlock(nn.Module):
         x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
 
+class DaViTStage(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth=1,
+            patch_size=4,
+            overlapped_patch=False,
+            attention_types=('spatial', 'channel'),
+            num_heads=3,
+            window_size=7,
+            mlp_ratio=4,
+            qkv_bias=True,
+            drop_path_rates=(0, 0),
+            norm_layer=nn.LayerNorm,
+            ffn=True,
+            cpe_act=False
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chs,
+            embed_dim=out_chs,
+            overlapped=overlapped_patch
+        )
+
+        stage_blocks = []
+
+        for block_idx in range(depth):
+            dual_attention_block = []
+
+            for attention_id, attention_type in enumerate(attention_types):
+                if attention_type == 'channel':
+                    dual_attention_block.append(ChannelBlock(
+                        dim=out_chs,
+                        num_heads=num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
+                        norm_layer=nn.LayerNorm,
+                        ffn=ffn,
+                        cpe_act=cpe_act
+                    ))
+                elif attention_type == 'spatial':
+                    dual_attention_block.append(SpatialBlock(
+                        dim=out_chs,
+                        num_heads=num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
+                        norm_layer=nn.LayerNorm,
+                        ffn=ffn,
+                        cpe_act=cpe_act,
+                        window_size=window_size,
+                    ))
+
+            stage_blocks.append(SequentialWithSize(*dual_attention_block))
+
+        self.blocks = SequentialWithSize(*stage_blocks)
+
+    def forward(self, x : Tensor, size: Tuple[int, int]):
+        x, size = self.patch_embed(x, size)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x, size = checkpoint_seq(self.blocks, x, size)
+        else:
+            x, size = self.blocks(x, size)
+        return x, size
+
 
 class DaViT(nn.Module):
 
@@ -392,7 +471,7 @@ class DaViT(nn.Module):
         self.embed_dims = embed_dims
         self.num_heads = num_heads
         self.num_stages = len(self.embed_dims)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
         assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
 
         self.num_classes = num_classes
@@ -401,47 +480,34 @@
         self.grad_checkpointing = False
         self.feature_info = []
 
-        self.patch_embeds = nn.ModuleList([
-            PatchEmbed(patch_size=patch_size if i == 0 else 2,
-                       in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
-                       embed_dim=self.embed_dims[i],
-                       overlapped=overlapped_patch)
-            for i in range(self.num_stages)])
-
-        self.stages = nn.ModuleList()
-        for stage_id, stage_param in enumerate(self.architecture):
-            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
-
-            stage = nn.ModuleList([
-                nn.ModuleList([
-                    ChannelBlock(
-                        dim=self.embed_dims[item],
-                        num_heads=self.num_heads[item],
-                        mlp_ratio=mlp_ratio,
-                        qkv_bias=qkv_bias,
-                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
-                        norm_layer=nn.LayerNorm,
-                        ffn=ffn,
-                        cpe_act=cpe_act
-                    ) if attention_type == 'channel' else
-                    SpatialBlock(
-                        dim=self.embed_dims[item],
-                        num_heads=self.num_heads[item],
-                        mlp_ratio=mlp_ratio,
-                        qkv_bias=qkv_bias,
-                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
-                        norm_layer=nn.LayerNorm,
-                        ffn=ffn,
-                        cpe_act=cpe_act,
-                        window_size=window_size,
-                    ) if attention_type == 'spatial' else None
-                    for attention_id, attention_type in enumerate(attention_types)]
-                ) for layer_id, item in enumerate(stage_param)
-            ])
+        stages = []
+
+        for stage_id in range(self.num_stages):
+            stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
+
+            stage = DaViTStage(
+                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
+                embed_dims[stage_id],
+                depth=depths[stage_id],
+                patch_size=patch_size,
+                overlapped_patch=overlapped_patch,
+                attention_types=attention_types,
+                num_heads=num_heads[stage_id],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                drop_path_rates=stage_drop_rates,
+                norm_layer=nn.LayerNorm,
+                ffn=ffn,
+                cpe_act=cpe_act
+            )
 
-            self.stages.add_module(f'stage_{stage_id}', stage)
-            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
-
+            stages.append(stage)
+            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
+
+        self.stages = SequentialWithSize(*stages)
+
         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         self.apply(self._init_weights)
@@ -471,66 +537,21 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-    def forward_network(self, x):
-        size: Tuple[int, int] = (x.size(2), x.size(3))
-        features = [x]
-        sizes = [size]
-
-        for patch_layer, stage in zip(self.patch_embeds, self.stages):
-            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for _, block in enumerate(stage):
-                for _, layer in enumerate(block):
-                    if self.grad_checkpointing and not torch.jit.is_scripting():
-                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
-                    else:
-                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])
-
-            # don't append outputs of last stage, since they are already there
-            if(len(features) < self.num_stages):
-                features.append(features[-1])
-                sizes.append(sizes[-1])
-
-
-        # non-normalized pyramid features + corresponding sizes
-        return features, sizes
-
-    def forward_pyramid_features(self, x) -> List[Tensor]:
-        x, sizes = self.forward_network(x)
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-
-        return outs
-
     def forward_features(self, x):
-        x, sizes = self.forward_network(x)
-        # take final feature and norm
-        x = self.norms(x[-1])
-        H, W = sizes[-1]
+        size: Tuple[int, int] = (x.size(2), x.size(3))
+        x, size = self.stages(x, size)
+        x = self.norms(x)
+        H, W = size
         x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
         return self.head(x, pre_logits=pre_logits)
 
-    def forward_classifier(self, x):
+    def forward(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
-
-    def forward(self, x):
-        return self.forward_classifier(x)
-
-
-class DaViTFeatures(DaViT):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
-
-    def forward(self, x) -> List[Tensor]:
-        return self.forward_pyramid_features(x)
 
 
 def checkpoint_filter_fn(state_dict, model):
@@ -551,25 +572,15 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_davit(variant, pretrained=False, **kwargs):
-    model_cls = DaViT
-    features_only = False
-    kwargs_filter = None
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-    if kwargs.pop('features_only', False):
-        model_cls = DaViTFeatures
-        kwargs_filter = ('num_classes', 'global_pool')
-        features_only = True
     model = build_model_with_cfg(
-        model_cls,
+        DaViT,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs)
-    if features_only:
-        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
-        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
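
Note (reviewer sketch, not part of the patch): a minimal, self-contained illustration of the (x, size) threading that the new SequentialWithSize container relies on. DummyBlock below is a hypothetical stand-in for the ChannelBlock/SpatialBlock pairs, which likewise take and return an (x, size) pair.

    from typing import Tuple

    import torch
    from torch import Tensor, nn


    class SequentialWithSize(nn.Sequential):
        def forward(self, x: Tensor, size: Tuple[int, int]):
            # thread both the token tensor and the spatial size through each child module
            for module in self._modules.values():
                x, size = module(x, size)
            return x, size


    class DummyBlock(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: Tensor, size: Tuple[int, int]):
            # blocks operate on (B, H*W, C) tokens and pass the spatial size along unchanged
            return self.proj(x), size


    blocks = SequentialWithSize(DummyBlock(96), DummyBlock(96))
    x = torch.randn(2, 56 * 56, 96)
    x, size = blocks(x, (56, 56))
    print(x.shape, size)  # torch.Size([2, 3136, 96]) (56, 56)

Subclassing nn.Sequential keeps the children addressable as stages.0, stages.1, ..., which is what the module=f'stages.{stage_id}' feature_info entries and feature_cfg=dict(flatten_sequential=True, ...) in _create_davit rely on for features_only support.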