diff --git a/timm/models/davit.py b/timm/models/davit.py index ef31be5a..99af9bca 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -389,8 +389,8 @@ class DaViTStage(nn.Module): for attention_id, attention_type in enumerate(attention_types): if attention_type == 'channel': dual_attention_block.append(ChannelBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], + dim=embed_dims[item], + num_heads=num_heads[item], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], @@ -400,8 +400,8 @@ class DaViTStage(nn.Module): )) elif attention_type == 'spatial': dual_attention_block.append(SpatialBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], + dim=embed_dims[item], + num_heads=num_heads[item], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],