@@ -372,24 +372,31 @@ class DaViTStage(nn.Module):
             cpe_act = False
     ):
         super().__init__()
 
         self.grad_checkpointing = False
 
         # patch embedding layer at the beginning of each stage
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
 
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
 
         potential opportunity to integrate with a more general version of ByobNet/ByoaNet
         since the logic is similar
         '''
         stage_blocks = []
 
         for block_idx in range(depth):
 
             dual_attention_block = []
 
             for attention_id, attention_type in enumerate(attention_types):
-                if attention_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                if attention_type == 'spatial':
+                    dual_attention_block.append(SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -397,10 +404,11 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act
+                        cpe_act=cpe_act,
+                        window_size=window_size,
                     ))
-                elif attention_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                elif attention_type == 'channel':
+                    dual_attention_block.append(ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -408,8 +416,7 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act,
-                        window_size=window_size,
+                        cpe_act=cpe_act
                     ))
 
             stage_blocks.append(nn.ModuleList(dual_attention_block))
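Note on the drop_path indexing in both branches: drop_path_rates[len(attention_types) * block_idx + attention_id] walks a flat list of per-block stochastic-depth rates, one entry per attention block in the stage. A quick worked example (the rate values here are made up for illustration):

attention_types = ('spatial', 'channel')
depth = 2
drop_path_rates = [0.0, 0.033, 0.067, 0.1]  # len == len(attention_types) * depth
for block_idx in range(depth):
    for attention_id, attention_type in enumerate(attention_types):
        rate = drop_path_rates[len(attention_types) * block_idx + attention_id]
        print(block_idx, attention_type, rate)
# 0 spatial 0.0
# 0 channel 0.033
# 1 spatial 0.067
# 1 channel 0.1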
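For reference, a minimal runnable sketch of the structure the loop builds: a ModuleList of (spatial, channel) block pairs, repeated depth times. The branch reorder in this diff is cosmetic (iteration order is set by attention_types, which defaults to ('spatial', 'channel')); window_size still only applies to the spatial/window attention. SpatialBlock and ChannelBlock below are simplified stand-ins, not the real timm modules:

import torch
import torch.nn as nn

class SpatialBlock(nn.Module):
    # stand-in for window attention over spatial tokens
    def __init__(self, dim, window_size=7):
        super().__init__()
        self.window_size = window_size
        self.mix = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mix(x)

class ChannelBlock(nn.Module):
    # stand-in for attention across channel groups (no window_size)
    def __init__(self, dim):
        super().__init__()
        self.mix = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mix(x)

depth, dim = 3, 96
stage_blocks = nn.ModuleList([
    nn.ModuleList([SpatialBlock(dim, window_size=7), ChannelBlock(dim)])
    for _ in range(depth)
])

x = torch.randn(2, 14 * 14, dim)  # (batch, tokens, dim)
for dual_attention_block in stage_blocks:
    for block in dual_attention_block:
        x = block(x)
print(x.shape)  # torch.Size([2, 196, 96])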