From 90164495b3645d4b92fc616321339bf3f368a2b9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 00:17:45 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index ef31be5a..99af9bca 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -389,8 +389,8 @@ class DaViTStage(nn.Module): for attention_id, attention_type in enumerate(attention_types): if attention_type == 'channel': dual_attention_block.append(ChannelBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], + dim=embed_dims[item], + num_heads=num_heads[item], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], @@ -400,8 +400,8 @@ class DaViTStage(nn.Module): )) elif attention_type == 'spatial': dual_attention_block.append(SpatialBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], + dim=embed_dims[item], + num_heads=num_heads[item], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],