Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 131259f065
commit 90164495b3

@ -389,8 +389,8 @@ class DaViTStage(nn.Module):
for attention_id, attention_type in enumerate(attention_types):
if attention_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],
dim=embed_dims[item],
num_heads=num_heads[item],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
@ -400,8 +400,8 @@ class DaViTStage(nn.Module):
))
elif attention_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],
dim=embed_dims[item],
num_heads=num_heads[item],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],

Loading…
Cancel
Save