diff --git a/timm/models/davit.py b/timm/models/davit.py index f6a1a0ae..bf8a8377 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -387,7 +387,7 @@ class DaViTStage(nn.Module): mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, ffn=ffn, cpe_act=cpe_act, window_size=window_size, @@ -399,7 +399,7 @@ class DaViTStage(nn.Module): mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, ffn=ffn, cpe_act=cpe_act ))