|
|
|
@ -1457,7 +1457,7 @@ static bool whisper_encode(
|
|
|
|
|
layer.cross_attn_k_w,
|
|
|
|
|
cur);
|
|
|
|
|
|
|
|
|
|
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
//Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
|
|
|
|
layer.cross_attn_v_w,
|
|
|
|
@ -1579,14 +1579,14 @@ static bool whisper_decode(
|
|
|
|
|
Qcur),
|
|
|
|
|
Qcur);
|
|
|
|
|
|
|
|
|
|
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
|
|
|
|
|
// note: no bias for Key
|
|
|
|
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
|
|
|
|
layer.attn_k_w,
|
|
|
|
|
cur);
|
|
|
|
|
|
|
|
|
|
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
|
|
|
|
|
layer.attn_v_w,
|
|
|
|
@ -1609,6 +1609,33 @@ static bool whisper_decode(
|
|
|
|
|
|
|
|
|
|
// ------
|
|
|
|
|
|
|
|
|
|
#ifdef USE_FLASH_ATTN
|
|
|
|
|
struct ggml_tensor * Q =
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
|
Qcur,
|
|
|
|
|
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * K =
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_reshape_3d(ctxL,
|
|
|
|
|
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
|
|
|
|
|
n_state/n_head, n_head, n_past + N),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * V =
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_reshape_3d(ctxL,
|
|
|
|
|
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
|
|
|
|
|
n_state/n_head, n_head, n_past + N),
|
|
|
|
|
1, 2, 0, 3),
|
|
|
|
|
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
|
|
|
|
|
#else
|
|
|
|
|
struct ggml_tensor * Q =
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
@ -1626,13 +1653,13 @@ static bool whisper_decode(
|
|
|
|
|
// K * Q
|
|
|
|
|
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
|
|
|
|
|
|
|
|
|
|
//struct ggml_tensor * KQ_scaled =
|
|
|
|
|
// ggml_scale(ctxL,
|
|
|
|
|
// KQ,
|
|
|
|
|
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
|
|
|
|
// );
|
|
|
|
|
struct ggml_tensor * KQ_scaled =
|
|
|
|
|
ggml_scale(ctxL,
|
|
|
|
|
KQ,
|
|
|
|
|
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
|
|
|
|
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
|
|
|
|
|
|
|
|
|
@ -1644,6 +1671,7 @@ static bool whisper_decode(
|
|
|
|
|
1, 2, 0, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
@ -1689,7 +1717,7 @@ static bool whisper_decode(
|
|
|
|
|
Qcur),
|
|
|
|
|
Qcur);
|
|
|
|
|
|
|
|
|
|
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
|
|
|
|
|
|
|
|
|
// Kcross is already scaled
|
|
|
|
|
struct ggml_tensor * Kcross =
|
|
|
|
@ -1704,6 +1732,24 @@ static bool whisper_decode(
|
|
|
|
|
|
|
|
|
|
// ------
|
|
|
|
|
|
|
|
|
|
#ifdef USE_FLASH_ATTN
|
|
|
|
|
struct ggml_tensor * Q =
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
|
Qcur,
|
|
|
|
|
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * V =
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
|
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
|
|
|
|
|
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
|
|
|
|
#else
|
|
|
|
|
struct ggml_tensor * Q =
|
|
|
|
|
ggml_permute(ctxL,
|
|
|
|
|
ggml_cpy(ctxL,
|
|
|
|
@ -1730,6 +1776,7 @@ static bool whisper_decode(
|
|
|
|
|
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|