whisper : reorder scratch buffers in the decoder

pull/431/head
Georgi Gerganov 2 years ago
parent 62205aedb9
commit 01669ee8b8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1705,18 +1705,18 @@ static bool whisper_encode(
} }
// cur // cur
{ //{
printf("ne0 = %d\n", cur->ne[0]); // printf("ne0 = %d\n", cur->ne[0]);
printf("ne1 = %d\n", cur->ne[1]); // printf("ne1 = %d\n", cur->ne[1]);
for (int i = 0; i < 10; ++i) { // for (int i = 0; i < 10; ++i) {
printf("%8.4f ", ((float *)(cur->data))[i]); // printf("%8.4f ", ((float *)(cur->data))[i]);
} // }
printf("... "); // printf("... ");
for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
printf("%8.4f ", ((float *)(cur->data))[i]); // printf("%8.4f ", ((float *)(cur->data))[i]);
} // }
printf("\n"); // printf("\n");
} //}
// pre-compute cross-attention memory // pre-compute cross-attention memory
{ {
@ -1768,12 +1768,12 @@ static bool whisper_encode(
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0, // ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0, // wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0, // wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0, // wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0); // wctx.get_buf_max_mem(3)/1024.0/1024.0);
ggml_free(ctx0); ggml_free(ctx0);
@ -1900,8 +1900,6 @@ static bool whisper_decode(
Vcur), Vcur),
Vcur); Vcur);
wctx.use_buf(ctx0, 0);
// store key and value to memory // store key and value to memory
{ {
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
@ -1913,7 +1911,7 @@ static bool whisper_decode(
// ------ // ------
wctx.use_buf(ctx0, 1); wctx.use_buf(ctx0, 0);
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctx0, ggml_permute(ctx0,
@ -1929,12 +1927,12 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N), n_state/n_head, n_head, n_past + N),
0, 2, 1, 3); 0, 2, 1, 3);
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 1);
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
wctx.use_buf(ctx0, 1); wctx.use_buf(ctx0, 0);
//struct ggml_tensor * KQ_scaled = //struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0, // ggml_scale(ctx0,
@ -1944,11 +1942,11 @@ static bool whisper_decode(
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
wctx.use_buf(ctx0, 1); wctx.use_buf(ctx0, 0);
struct ggml_tensor * V_trans = struct ggml_tensor * V_trans =
ggml_permute(ctx0, ggml_permute(ctx0,
@ -1957,14 +1955,12 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N), n_state/n_head, n_head, n_past + N),
1, 2, 0, 3); 1, 2, 0, 3);
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
wctx.use_buf(ctx0, 1);
cur = ggml_cpy(ctx0, cur = ggml_cpy(ctx0,
KQV_merged, KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
@ -2200,12 +2196,12 @@ static bool whisper_decode(
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
if (N > 1) { if (N > 1) {
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0, // ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0, // wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0, // wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0, // wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0); // wctx.get_buf_max_mem(3)/1024.0/1024.0);
} }
ggml_free(ctx0); ggml_free(ctx0);

Loading…
Cancel
Save