diff --git a/whisper.cpp b/whisper.cpp
index d52143f..46a074a 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1659,7 +1659,7 @@ static bool whisper_encode(
                     layer.cross_attn_k_w,
                     cur);
 
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                     layer.cross_attn_v_w,
@@ -1793,14 +1793,14 @@ static bool whisper_decode(
                             Qcur),
                         Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                     layer.attn_k_w,
                     cur);
 
-            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                     layer.attn_v_w,
@@ -1823,6 +1823,33 @@ static bool whisper_decode(
 
             // ------
 
+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctxL,
+                        ggml_reshape_3d(ctxL,
+                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                            n_state/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL,
+                            ggml_reshape_3d(ctxL,
+                                ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
+                        );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1840,13 +1867,13 @@ static bool whisper_decode(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
 
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
-            //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctxL,
+                        KQ,
+                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+                        );
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
 
@@ -1858,6 +1885,7 @@ static bool whisper_decode(
                         1, 2, 0, 3);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif
 
             struct ggml_tensor * KQV_merged =
                 ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@@ -1903,7 +1931,7 @@ static bool whisper_decode(
                             Qcur),
                         Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
@@ -1918,6 +1946,24 @@ static bool whisper_decode(
 
             // ------
 
+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
+                        );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1944,6 +1990,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif
 
             struct ggml_tensor * KQV_merged =
                 ggml_permute(ctxL, KQV, 0, 2, 1, 3);
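Note on the scaling change: the non-flash path used to pre-scale Q and K each by pow(float(n_state)/n_head, -0.25); this patch comments those scales out and re-enables the single KQ_scaled multiplication by 1.0f/sqrt(float(n_state)/n_head) after the K * Q matmul. The two formulations are equivalent, because scaling both operands of a dot product by d^-1/4 scales the result by d^-1/2, and keeping Q and K unscaled is what allows the same tensors (and the unscaled memory_k cache) to be handed straight to ggml_flash_attn, which is expected to apply the 1/sqrt(d) factor inside the fused kernel. A standalone sanity check of the equivalence (plain C, not part of the patch; the vector values are arbitrary):

// Sanity check: scaling Q and K each by d^-0.25 before the dot product
// gives the same score as scaling the raw dot product once by 1/sqrt(d).
// Here d stands in for n_state/n_head; the values are made up.
#include <math.h>
#include <stdio.h>

int main(void) {
    enum { d = 64 };
    float q[d], k[d];

    for (int i = 0; i < d; i++) {
        q[i] = 0.01f*(float)(i + 1);
        k[i] = 0.02f*(float)(d - i);
    }

    const float c = powf((float) d, -0.25f);

    float pre = 0.0f, post = 0.0f;
    for (int i = 0; i < d; i++) {
        pre  += (c*q[i])*(c*k[i]); // old path: Q and K pre-scaled
        post += q[i]*k[i];         // new path: raw dot product ...
    }
    post *= 1.0f/sqrtf((float) d); // ... scaled once afterwards (KQ_scaled)

    printf("pre-scaled:  %.6f\npost-scaled: %.6f\n", pre, post);
    return 0;
}

Both sums agree up to float rounding, so moving the scale after the matmul changes nothing numerically in the non-flash path.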
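For orientation, the #else branch materializes the attention matrix step by step (KQ, KQ_scaled, KQ_masked, KQ_soft_max, then KQV via V_trans), while the USE_FLASH_ATTN branch replaces that whole chain with a single fused ggml_flash_attn(ctxL, Q, K, V, masked) call; masked is true for the causal self-attention and false for cross-attention. Assuming the fused call computes standard masked scaled-dot-product attention, a naive single-head reference of the intended result looks like the sketch below (illustrative only; attn_ref, its row-major layout, and the test values are made up here and are not ggml's API):

// Naive single-head reference for what the fused call is assumed to
// compute: softmax(Q*K^T / sqrt(d) + causal mask) * V.
// Row-major layout: q is n_q x d, k and v are n_kv x d, out is n_q x d.
// "masked" mirrors the last argument of ggml_flash_attn: when nonzero,
// query i attends only to keys 0..n_past+i (the ggml_diag_mask_inf rule).
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

static void attn_ref(const float *q, const float *k, const float *v,
                     float *out, int n_q, int n_kv, int d,
                     int n_past, int masked) {
    const float scale = 1.0f/sqrtf((float) d);
    float *s = (float *) malloc(sizeof(float)*(size_t) n_kv);

    for (int i = 0; i < n_q; i++) {
        const int lim = masked ? (n_past + i + 1) : n_kv;

        // scaled dot-product scores against the visible keys
        float smax = -INFINITY;
        for (int j = 0; j < lim; j++) {
            float dot = 0.0f;
            for (int t = 0; t < d; t++) {
                dot += q[i*d + t]*k[j*d + t];
            }
            s[j] = dot*scale;
            if (s[j] > smax) smax = s[j];
        }

        // softmax over the visible keys; masked-out keys get weight 0
        float sum = 0.0f;
        for (int j = 0; j < lim; j++) {
            s[j] = expf(s[j] - smax);
            sum += s[j];
        }

        // weighted sum of the value rows
        for (int t = 0; t < d; t++) {
            float acc = 0.0f;
            for (int j = 0; j < lim; j++) {
                acc += (s[j]/sum)*v[j*d + t];
            }
            out[i*d + t] = acc;
        }
    }

    free(s);
}

int main(void) {
    enum { n_q = 4, n_kv = 6, d = 8 };
    float q[n_q*d], k[n_kv*d], v[n_kv*d], out[n_q*d];

    for (int i = 0; i < n_q*d;  i++) q[i] = 0.01f*(float) i;
    for (int i = 0; i < n_kv*d; i++) { k[i] = 0.02f*(float) i; v[i] = 0.03f*(float) i; }

    // n_past = 2 pretends two cached tokens precede the n_q new ones
    attn_ref(q, k, v, out, n_q, n_kv, d, 2, 1);

    printf("out[0][0] = %f\n", out[0]);
    return 0;
}

The practical difference is that the fused path never has to store the full n_q x n_kv score matrix, which is exactly what the explicit KQ tensor costs in the #else branch.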