Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 52093607f8
commit b83430350f

@ -372,24 +372,31 @@ class DaViTStage(nn.Module):
cpe_act = False
):
super().__init__()
self.grad_checkpointing = False
# patch embedding layer at the beginning of each stage
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chs,
embed_dim=out_chs,
overlapped=overlapped_patch
)
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
potential opportunity to integrate with a more general version of ByobNet/ByoaNet
since the logic is similar
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attention_id, attention_type in enumerate(attention_types):
if attention_type == 'channel':
dual_attention_block.append(ChannelBlock(
if attention_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@ -397,10 +404,11 @@ class DaViTStage(nn.Module):
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act
cpe_act=cpe_act,
window_size=window_size,
))
elif attention_type == 'spatial':
dual_attention_block.append(SpatialBlock(
elif attention_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@ -408,8 +416,7 @@ class DaViTStage(nn.Module):
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
cpe_act=cpe_act
))
stage_blocks.append(nn.ModuleList(dual_attention_block))

Loading…
Cancel
Save