From 6604ba73d7fd0f359a1279c11cd0237fe625d856 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 00:18:54 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 99af9bca..51162c5a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -389,8 +389,8 @@ class DaViTStage(nn.Module): for attention_id, attention_type in enumerate(attention_types): if attention_type == 'channel': dual_attention_block.append(ChannelBlock( - dim=embed_dims[item], - num_heads=num_heads[item], + dim=out_chs, + num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], @@ -400,8 +400,8 @@ class DaViTStage(nn.Module): )) elif attention_type == 'spatial': dual_attention_block.append(SpatialBlock( - dim=embed_dims[item], - num_heads=num_heads[item], + dim=out_chs, + num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],