Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 52093607f8
commit b83430350f

@@ -372,24 +372,31 @@ class DaViTStage(nn.Module):
             cpe_act = False
     ):
         super().__init__()
         self.grad_checkpointing = False
+        # patch embedding layer at the beginning of each stage
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
+        '''
+        repeating alternating attention blocks in each stage
+        default: (spatial -> channel) x depth
+        potential opportunity to integrate with a more general version of ByobNet/ByoaNet
+        since the logic is similar
+        '''
         stage_blocks = []
         for block_idx in range(depth):
             dual_attention_block = []
             for attention_id, attention_type in enumerate(attention_types):
-                if attention_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                if attention_type == 'spatial':
+                    dual_attention_block.append(SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -397,10 +404,11 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act
+                        cpe_act=cpe_act,
+                        window_size=window_size,
                     ))
-                elif attention_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                elif attention_type == 'channel':
+                    dual_attention_block.append(ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -408,8 +416,7 @@ class DaViTStage(nn.Module):
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                         norm_layer=nn.LayerNorm,
                         ffn=ffn,
-                        cpe_act=cpe_act,
-                        window_size=window_size,
+                        cpe_act=cpe_act
                     ))
             stage_blocks.append(nn.ModuleList(dual_attention_block))
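Below is a minimal, self-contained sketch of the construction pattern this change settles on: each stage holds depth dual-attention blocks, and each dual block applies the sub-blocks named in attention_types in order (default spatial then channel), with one drop-path rate per sub-block indexed as len(attention_types) * block_idx + attention_id. The _SpatialStub and _ChannelStub modules are hypothetical placeholders standing in for SpatialBlock and ChannelBlock so the snippet runs without timm; only the loop structure and indexing mirror the diff.

import torch
import torch.nn as nn


class _SpatialStub(nn.Module):
    # Placeholder for SpatialBlock (window attention over spatial tokens).
    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.drop_path = drop_path

    def forward(self, x):
        return x + self.proj(x)


class _ChannelStub(nn.Module):
    # Placeholder for ChannelBlock (attention across channel groups).
    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.drop_path = drop_path

    def forward(self, x):
        return x + self.proj(x)


def build_stage_blocks(out_chs, depth, attention_types=('spatial', 'channel'), drop_path_rates=None):
    # One drop-path rate per attention sub-block, indexed exactly as in the diff:
    # drop_path_rates[len(attention_types) * block_idx + attention_id]
    if drop_path_rates is None:
        drop_path_rates = [0.] * (len(attention_types) * depth)
    stage_blocks = []
    for block_idx in range(depth):
        dual_attention_block = []
        for attention_id, attention_type in enumerate(attention_types):
            dp = drop_path_rates[len(attention_types) * block_idx + attention_id]
            if attention_type == 'spatial':
                dual_attention_block.append(_SpatialStub(out_chs, drop_path=dp))
            elif attention_type == 'channel':
                dual_attention_block.append(_ChannelStub(out_chs, drop_path=dp))
        stage_blocks.append(nn.ModuleList(dual_attention_block))
    return nn.ModuleList(stage_blocks)


blocks = build_stage_blocks(out_chs=96, depth=3)
x = torch.randn(2, 56 * 56, 96)  # (B, N, C) token layout assumed for the stubs
for dual in blocks:
    for blk in dual:
        x = blk(x)
print(len(blocks), x.shape)  # 3 dual blocks -> (spatial, channel) applied 3 times

Note also that the diff moves window_size from ChannelBlock onto SpatialBlock, presumably because window partitioning is a spatial-attention concern; the channel branch keeps only cpe_act.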
