|
|
|
@ -389,8 +389,8 @@ class DaViTStage(nn.Module):
|
|
|
|
|
for attention_id, attention_type in enumerate(attention_types):
|
|
|
|
|
if attention_type == 'channel':
|
|
|
|
|
dual_attention_block.append(ChannelBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
dim=embed_dims[item],
|
|
|
|
|
num_heads=num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
@ -400,8 +400,8 @@ class DaViTStage(nn.Module):
|
|
|
|
|
))
|
|
|
|
|
elif attention_type == 'spatial':
|
|
|
|
|
dual_attention_block.append(SpatialBlock(
|
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
|
num_heads=self.num_heads[item],
|
|
|
|
|
dim=embed_dims[item],
|
|
|
|
|
num_heads=num_heads[item],
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
|