From 01669ee8b892ffc840db173d3eb89f55ffd0ab98 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Jan 2023 09:26:44 +0200 Subject: [PATCH] whisper : reorder scratch buffers in the decoder --- whisper.cpp | 64 +++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 3c73125..f123ed8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1705,18 +1705,18 @@ static bool whisper_encode( } // cur - { - printf("ne0 = %d\n", cur->ne[0]); - printf("ne1 = %d\n", cur->ne[1]); - for (int i = 0; i < 10; ++i) { - printf("%8.4f ", ((float *)(cur->data))[i]); - } - printf("... "); - for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - printf("%8.4f ", ((float *)(cur->data))[i]); - } - printf("\n"); - } + //{ + // printf("ne0 = %d\n", cur->ne[0]); + // printf("ne1 = %d\n", cur->ne[1]); + // for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("... "); + // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("\n"); + //} // pre-compute cross-attention memory { @@ -1768,12 +1768,12 @@ static bool whisper_encode( //////////////////////////////////////////////////////////////////////////// - printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - ggml_used_mem(ctx0)/1024.0/1024.0, - wctx.get_buf_max_mem(0)/1024.0/1024.0, - wctx.get_buf_max_mem(1)/1024.0/1024.0, - wctx.get_buf_max_mem(2)/1024.0/1024.0, - wctx.get_buf_max_mem(3)/1024.0/1024.0); + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); ggml_free(ctx0); @@ -1900,8 +1900,6 @@ static bool whisper_decode( Vcur), Vcur); - wctx.use_buf(ctx0, 0); - // store key and value to memory { struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); @@ -1913,7 +1911,7 @@ static bool whisper_decode( // ------ - wctx.use_buf(ctx0, 1); + wctx.use_buf(ctx0, 0); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1929,12 +1927,12 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - wctx.use_buf(ctx0, 0); + wctx.use_buf(ctx0, 1); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - wctx.use_buf(ctx0, 1); + wctx.use_buf(ctx0, 0); //struct ggml_tensor * KQ_scaled = // ggml_scale(ctx0, @@ -1944,11 +1942,11 @@ static bool whisper_decode( struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - wctx.use_buf(ctx0, 0); + wctx.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - wctx.use_buf(ctx0, 1); + wctx.use_buf(ctx0, 0); struct ggml_tensor * V_trans = ggml_permute(ctx0, @@ -1957,14 +1955,12 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - wctx.use_buf(ctx0, 0); + wctx.use_buf(ctx0, 1); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); - cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); @@ -2200,12 +2196,12 @@ static bool whisper_decode( memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); if (N > 1) { - printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - ggml_used_mem(ctx0)/1024.0/1024.0, - wctx.get_buf_max_mem(0)/1024.0/1024.0, - wctx.get_buf_max_mem(1)/1024.0/1024.0, - wctx.get_buf_max_mem(2)/1024.0/1024.0, - wctx.get_buf_max_mem(3)/1024.0/1024.0); + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); } ggml_free(ctx0);