diff --git a/timm/models/davit.py b/timm/models/davit.py index 99af9bca..51162c5a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -389,8 +389,8 @@ class DaViTStage(nn.Module): for attention_id, attention_type in enumerate(attention_types): if attention_type == 'channel': dual_attention_block.append(ChannelBlock( - dim=embed_dims[item], - num_heads=num_heads[item], + dim=out_chs, + num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], @@ -400,8 +400,8 @@ class DaViTStage(nn.Module): )) elif attention_type == 'spatial': dual_attention_block.append(SpatialBlock( - dim=embed_dims[item], - num_heads=num_heads[item], + dim=out_chs, + num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],