whisper : experiments with Flash Attention in the decoder

pull/284/head
Georgi Gerganov 1 year ago
parent f30b5d322c
commit e2aa556a99
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1457,7 +1457,7 @@ static bool whisper_encode(
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
//Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
@ -1579,14 +1579,14 @@ static bool whisper_decode(
Qcur),
Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
layer.attn_k_w,
cur);
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
layer.attn_v_w,
@ -1609,6 +1609,33 @@ static bool whisper_decode(
// ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
#else
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
@ -1626,13 +1653,13 @@ static bool whisper_decode(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctxL,
// KQ,
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
// );
struct ggml_tensor * KQ_scaled =
ggml_scale(ctxL,
KQ,
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
@ -1644,6 +1671,7 @@ static bool whisper_decode(
1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1689,7 +1717,7 @@ static bool whisper_decode(
Qcur),
Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// Kcross is already scaled
struct ggml_tensor * Kcross =
@ -1704,6 +1732,24 @@ static bool whisper_decode(
// ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
@ -1730,6 +1776,7 @@ static bool whisper_decode(
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

Loading…
Cancel
Save