From b83430350fb84af79e7aef2e921eab70fa5b0b7c Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 04:10:22 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index dfc4203d..076e84b5 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -372,24 +372,31 @@ class DaViTStage(nn.Module):
             cpe_act = False
     ):
         super().__init__()
+        self.grad_checkpointing = False
+
         # patch embedding layer at the beginning of each stage
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
-
+        '''
+        repeating alternating attention blocks in each stage
+        default: (spatial -> channel) x depth
+
+        potential opportunity to integrate with a more general version of ByobNet/ByoaNet
+        since the logic is similar
+        '''
         stage_blocks = []
-
         for block_idx in range(depth):
             dual_attention_block = []
             for attention_id, attention_type in enumerate(attention_types):
-                if attention_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                if attention_type == 'spatial':
+                    dual_attention_block.append(SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -397,10 +404,11 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act
+                        cpe_act=cpe_act,
+                        window_size=window_size,
                     ))
-                elif attention_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                elif attention_type == 'channel':
+                    dual_attention_block.append(ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -408,8 +416,7 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act,
-                        window_size=window_size,
+                        cpe_act=cpe_act
                     ))
             stage_blocks.append(nn.ModuleList(dual_attention_block))
 
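
For context, the diff swaps the two attention branches so each label constructs the matching block ('spatial' -> SpatialBlock, 'channel' -> ChannelBlock) and moves window_size onto the spatial branch, where windowed attention actually consumes it; it also adds a grad_checkpointing flag on the stage. The minimal sketch below is illustrative only, not the patch's code: SpatialBlock and ChannelBlock are hypothetical stand-ins (residual LayerNorms in place of timm's real attention blocks), used to show the resulting (spatial -> channel) x depth wiring and how such a flag is commonly gated through torch.utils.checkpoint.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


# Hypothetical stand-ins: timm's real SpatialBlock/ChannelBlock implement
# windowed self-attention and channel-group attention; these are reduced to
# residual LayerNorms so only the stage wiring itself is exercised.
class SpatialBlock(nn.Module):
    def __init__(self, dim, window_size=7):
        super().__init__()
        self.window_size = window_size  # only the windowed (spatial) path needs this
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.norm(x)


class ChannelBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.norm(x)


class ToyDualStage(nn.Module):
    """(spatial -> channel) x depth, matching the corrected branch order above."""

    def __init__(self, dim, depth=2, window_size=7, attention_types=('spatial', 'channel')):
        super().__init__()
        self.grad_checkpointing = False
        blocks = []
        for _ in range(depth):
            pair = []
            for attention_type in attention_types:
                if attention_type == 'spatial':
                    # window_size is forwarded only to the spatial block, as in the diff
                    pair.append(SpatialBlock(dim, window_size=window_size))
                elif attention_type == 'channel':
                    pair.append(ChannelBlock(dim))
            blocks.append(nn.Sequential(*pair))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # recompute activations in backward to trade compute for memory
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x


stage = ToyDualStage(dim=96, depth=2)
out = stage(torch.randn(2, 196, 96))   # (B, N, C) token layout
print(out.shape)                       # torch.Size([2, 196, 96])

Keeping window_size out of ChannelBlock matters because channel-group attention mixes information across feature channels rather than spatial windows, so the parameter has no meaning there.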