pull/284/merge
Georgi Gerganov 2 years ago committed by GitHub
commit 2e148e374d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1659,7 +1659,7 @@ static bool whisper_encode(
layer.cross_attn_k_w, layer.cross_attn_k_w,
cur); cur);
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w, layer.cross_attn_v_w,
@ -1793,14 +1793,14 @@ static bool whisper_decode(
Qcur), Qcur),
Qcur); Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key // note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
layer.attn_k_w, layer.attn_k_w,
cur); cur);
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
layer.attn_v_w, layer.attn_v_w,
@ -1823,6 +1823,33 @@ static bool whisper_decode(
// ------ // ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
#else
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
@ -1840,13 +1867,13 @@ static bool whisper_decode(
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
//struct ggml_tensor * KQ_scaled = struct ggml_tensor * KQ_scaled =
// ggml_scale(ctxL, ggml_scale(ctxL,
// KQ, KQ,
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
// ); );
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
@ -1858,6 +1885,7 @@ static bool whisper_decode(
1, 2, 0, 3); 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1903,7 +1931,7 @@ static bool whisper_decode(
Qcur), Qcur),
Qcur); Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// Kcross is already scaled // Kcross is already scaled
struct ggml_tensor * Kcross = struct ggml_tensor * Kcross =
@ -1918,6 +1946,24 @@ static bool whisper_decode(
// ------ // ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
@ -1944,6 +1990,7 @@ static bool whisper_decode(
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

Loading…
Cancel
Save