From 1a1dee46ceab23581d1e44c48fdbd77526d7a8ba Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 22 Jan 2023 13:10:45 +0200
Subject: [PATCH] whisper : use scratch buffers during inference (base model only)

---
 examples/main/main.cpp |   2 +-
 whisper.cpp            | 636 ++++++++++++++++++++++++-----------------
 2 files changed, 372 insertions(+), 266 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 10d6b3c..8deb3f0 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -514,7 +514,7 @@ int main(int argc, char ** argv) {
 
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
-        const auto fname_outp = f < params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
+        const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
 
         std::vector<float> pcmf32;               // mono-channel F32 PCM
         std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
diff --git a/whisper.cpp b/whisper.cpp
index e1f4553..bf6a30c 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -99,10 +99,36 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define WHISPER_PRINT_DEBUG(...)
 #endif
 
+#define WHISPER_USE_SCRATCH
 #define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
+// TODO: make these part of the whisper_context instead of macros
+#if defined(WHISPER_USE_SCRATCH)
+static int    g_buf_last = 0;
+static size_t g_buf_max_size[16] = { 0 };
+#define USE_BUF(wctx, gctx, i) \
+    do { \
+        size_t last_size = 0; \
+        if (i == -1) { \
+            last_size = ggml_set_scratch(gctx, { 0, 0, nullptr, }); \
+        } else { \
+            auto & buf = wctx.buf_scratch[i]; \
+            last_size = ggml_set_scratch(gctx, { 0, buf.size(), buf.data(), }); \
+        } \
+        if (g_buf_last >= 0) { \
+            g_buf_max_size[g_buf_last] = std::max(g_buf_max_size[g_buf_last], last_size); \
+        } \
+        g_buf_last = i; \
+    } while (0)
+
+#define GET_BUF_MAX_MEM(i) g_buf_max_size[i]
+#else
+#define USE_BUF(wctx, gctx, i) do {} while (0)
+#define GET_BUF_MAX_MEM(i) 0
+#endif
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -217,9 +243,33 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1024*1024;
 
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH = {
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+    { MODEL_TINY,     132ull*MB },
+    { MODEL_BASE,      15ull*MB },
+    { MODEL_SMALL,    164ull*MB },
+    { MODEL_MEDIUM,   184ull*MB },
+    { MODEL_LARGE,    110ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
     { MODEL_TINY,     132ull*MB },
-    { MODEL_BASE,     144ull*MB },
+    { MODEL_BASE,      24ull*MB },
+    { MODEL_SMALL,    164ull*MB },
+    { MODEL_MEDIUM,   184ull*MB },
+    { MODEL_LARGE,    110ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+    { MODEL_TINY,     132ull*MB },
+    { MODEL_BASE,       4ull*MB },
+    { MODEL_SMALL,    164ull*MB },
+    { MODEL_MEDIUM,   184ull*MB },
+    { MODEL_LARGE,    110ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+    { MODEL_TINY,     132ull*MB },
+    { MODEL_BASE,       4ull*MB },
     { MODEL_SMALL,    164ull*MB },
     { MODEL_MEDIUM,   184ull*MB },
     { MODEL_LARGE,    110ull*MB },
@@ -251,36 +301,20 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
 
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
     { MODEL_TINY,      80ull*MB },
-    { MODEL_BASE,     128ull*MB },
+    { MODEL_BASE,       6ull*MB },
     { MODEL_SMALL,    300ull*MB },
     { MODEL_MEDIUM,   680ull*MB },
     { MODEL_LARGE,   1100ull*MB },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
-    { MODEL_TINY,     104ull*MB },
-    { MODEL_BASE,     138ull*MB },
-    { MODEL_SMALL,    208ull*MB },
-    { MODEL_MEDIUM,   280ull*MB },
- { MODEL_LARGE, 354ull*MB }, -}; - static const std::map MEM_REQ_DECODE = { { MODEL_TINY, 200ull*MB }, - { MODEL_BASE, 202ull*MB }, + { MODEL_BASE, 5ull*MB }, { MODEL_SMALL, 204ull*MB }, { MODEL_MEDIUM, 206ull*MB }, { MODEL_LARGE, 208ull*MB }, }; -static const std::map MEM_REQ_DECODE_LAYER = { - { MODEL_TINY, 32ull*MB }, - { MODEL_BASE, 44ull*MB }, - { MODEL_SMALL, 64ull*MB }, - { MODEL_MEDIUM, 84ull*MB }, - { MODEL_LARGE, 110ull*MB }, -}; - struct whisper_mel { int n_len; int n_mel; @@ -564,9 +598,8 @@ struct whisper_context { whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; // memory buffers used by encode / decode contexts - std::vector buf_scratch; + std::vector buf_scratch[4]; std::vector buf_compute; - std::vector buf_compute_layer; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -753,11 +786,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { // this is the total memory required to run the inference const size_t mem_required = - MEM_REQ_SCRATCH.at (model.type) + + MEM_REQ_SCRATCH0.at (model.type) + + MEM_REQ_SCRATCH1.at (model.type) + + MEM_REQ_SCRATCH2.at (model.type) + + MEM_REQ_SCRATCH3.at (model.type) + scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) + - scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)); + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); // this is the memory required by one decoder const size_t mem_required_decoder = @@ -793,9 +828,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } - wctx.buf_scratch.resize (MEM_REQ_SCRATCH.at(model.type)); - wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); + wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); + wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); + wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); + + wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); } // load mel filters @@ -1328,6 +1366,8 @@ static bool whisper_encode( struct ggml_context * ctx0 = ggml_init(params); + USE_BUF(wctx, ctx0, 0); + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); { @@ -1346,10 +1386,10 @@ static bool whisper_encode( struct ggml_tensor * cur; - ggml_set_scratch(ctx0, { 0, wctx.buf_scratch.size(), wctx.buf_scratch.data(), }); - // convolution + gelu { + USE_BUF(wctx, ctx0, 1); + cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1357,9 +1397,10 @@ static bool whisper_encode( cur), cur); - cur = ggml_gelu(ctx0, cur); + USE_BUF(wctx, ctx0, 0); + cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1370,7 +1411,7 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } - ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + USE_BUF(wctx, ctx0, 3); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ 
-1403,157 +1444,158 @@ static bool whisper_encode( for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_encoder[il]; - // create separate context for each layer to reduce memory usage - - struct ggml_init_params paramsL; - paramsL.mem_size = wctx.buf_compute_layer.size(); - paramsL.mem_buffer = wctx.buf_compute_layer.data(); - - struct ggml_context * ctxL = ggml_init(paramsL); - - ggml_set_scratch(ctxL, { 0, wctx.buf_scratch.size(), wctx.buf_scratch.data(), }); - // norm { - cur = ggml_norm(ctxL, inpL); + USE_BUF(wctx, ctx0, 0); + + cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } - ggml_set_scratch(ctxL, { 0, 0, nullptr, }); - // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_q_b, Qcur), Qcur); - //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_v_b, Vcur), Vcur); // ------ + USE_BUF(wctx, ctx0, 0); + #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head) + ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) ); - struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); #else struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, 
n_head, n_ctx)), 0, 2, 1, 3); // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ_scaled = - ggml_scale(ctxL, + ggml_scale(ctx0, KQ, - ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) ); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); //struct ggml_tensor * V_trans = - // ggml_permute(ctxL, - // ggml_cpy(ctxL, + // ggml_permute(ctx0, + // ggml_cpy(ctx0, // Vcur, - // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); - //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head) + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) ); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); #endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + USE_BUF(wctx, ctx0, 1); - cur = ggml_cpy(ctxL, + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); } // projection { - cur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + USE_BUF(wctx, ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } + USE_BUF(wctx, ctx0, 2); + // add the input - cur = ggml_add(ctxL, cur, inpL); + cur = ggml_add(ctx0, cur, inpL); struct ggml_tensor * inpFF = cur; @@ -1561,75 +1603,75 @@ static bool whisper_encode( { // norm { - cur = ggml_norm(ctxL, inpFF); + USE_BUF(wctx, ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + USE_BUF(wctx, ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } #ifdef WHISPER_USE_FLASH_FF - cur = ggml_flash_ff(ctxL, - ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)), + USE_BUF(wctx, ctx0, 0); + + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, N)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else + USE_BUF(wctx, ctx0, 0); + // fully connected - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), + USE_BUF(wctx, ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); + USE_BUF(wctx, ctx0, 0); + // GELU activation - cur = ggml_gelu(ctxL, cur); + cur = ggml_gelu(ctx0, cur); + + USE_BUF(wctx, ctx0, 1); // projection - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - cur = 
ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), + USE_BUF(wctx, ctx0, 0); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); #endif } - // output from this layer - struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); + USE_BUF(wctx, ctx0, 3); - { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); - - //ggml_graph_print(&gf); - } - - // TODO: this is a hack to have per-layer computation graphs - need to come up with something better - // input for next layer (inpO -> inpL) - memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; - inpL->src0 = nullptr; - inpL->src1 = nullptr; - - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); - - ggml_free(ctxL); + inpL = ggml_add(ctx0, cur, inpFF); } cur = inpL; // norm { + USE_BUF(wctx, ctx0, 0); + cur = ggml_norm(ctx0, cur); + USE_BUF(wctx, ctx0, 1); + // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -1638,6 +1680,8 @@ static bool whisper_encode( ggml_repeat(ctx0, model.e_ln_b, cur)); } + USE_BUF(wctx, ctx0, -1); + // run the computation { struct ggml_cgraph gf = {}; @@ -1650,18 +1694,18 @@ static bool whisper_encode( } // cur - //{ - // printf("ne0 = %d\n", cur->ne[0]); - // printf("ne1 = %d\n", cur->ne[1]); - // for (int i = 0; i < 10; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("... "); - // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("\n"); - //} + { + printf("ne0 = %d\n", cur->ne[0]); + printf("ne1 = %d\n", cur->ne[1]); + for (int i = 0; i < 10; ++i) { + printf("%8.4f ", ((float *)(cur->data))[i]); + } + printf("... 
"); + for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + printf("%8.4f ", ((float *)(cur->data))[i]); + } + printf("\n"); + } // pre-compute cross-attention memory { @@ -1676,12 +1720,16 @@ static bool whisper_encode( for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; + USE_BUF(wctx, ctx0, 0); + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + USE_BUF(wctx, ctx0, 1); + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); @@ -1692,6 +1740,8 @@ static bool whisper_encode( Vcross), Vcross); + USE_BUF(wctx, ctx0, -1); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); @@ -1707,7 +1757,8 @@ static bool whisper_encode( //////////////////////////////////////////////////////////////////////////// - //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0); + printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, + GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0); ggml_free(ctx0); @@ -1719,7 +1770,7 @@ static bool whisper_encode( // evaluate the decoder // -// given text prompt + audio features -> predicts the probabilities for the next token +// given text prompt + audio features -> computes the logits for the next token // // - model: the model // - n_threads: number of threads to use @@ -1763,6 +1814,9 @@ static bool whisper_decode( struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1771,6 +1825,8 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } + USE_BUF(wctx, ctx0, 3); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -1782,211 +1838,252 @@ static bool whisper_decode( for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_decoder[il]; - struct ggml_init_params paramsL; - paramsL.mem_size = wctx.buf_compute_layer.size(); - paramsL.mem_buffer = wctx.buf_compute_layer.data(); - - struct ggml_context * ctxL = ggml_init(paramsL); - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - // norm { - cur = ggml_norm(ctxL, inpL); + USE_BUF(wctx, ctx0, 0); + + cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_q_b, Qcur), Qcur); - Qcur = 
ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_v_b, Vcur), Vcur); + USE_BUF(wctx, ctx0, 0); + // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); - ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); } // ------ + USE_BUF(wctx, ctx0, 1); + struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); + USE_BUF(wctx, ctx0, 0); + // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + USE_BUF(wctx, ctx0, 1); //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, + // ggml_scale(ctx0, // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) // ); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + + USE_BUF(wctx, ctx0, 0); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + USE_BUF(wctx, ctx0, 1); struct ggml_tensor * V_trans = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + USE_BUF(wctx, ctx0, 0); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + 
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctxL, + USE_BUF(wctx, ctx0, 1); + + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); } + // projection { - cur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + USE_BUF(wctx, ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } + USE_BUF(wctx, ctx0, 2); + // add the input - struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL); + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here + USE_BUF(wctx, ctx0, 0); + + cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here + + USE_BUF(wctx, ctx0, 1); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); } // cross-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 0); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_q_b, Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // Kcross is already scaled struct ggml_tensor * Kcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), n_state/n_head, n_head, M); + struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); + // ------ + USE_BUF(wctx, ctx0, 1); + struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), 0, 2, 1, 3); - struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3); + struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); + + USE_BUF(wctx, ctx0, 0); // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, + // ggml_scale(ctx0, // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) // ); // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + USE_BUF(wctx, ctx0, 1); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, 
KQ); - struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); + USE_BUF(wctx, ctx0, 0); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + USE_BUF(wctx, ctx0, 1); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_state, N) - cur = ggml_cpy(ctxL, + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); } // projection { - cur = ggml_mul_mat(ctxL, + USE_BUF(wctx, ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur), + USE_BUF(wctx, ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } + USE_BUF(wctx, ctx0, 2); + // add the input - cur = ggml_add(ctxL, cur, inpCA); + cur = ggml_add(ctx0, cur, inpCA); struct ggml_tensor * inpFF = cur; @@ -1994,68 +2091,67 @@ static bool whisper_decode( { // norm { - cur = ggml_norm(ctxL, inpFF); + USE_BUF(wctx, ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + USE_BUF(wctx, ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } + USE_BUF(wctx, ctx0, 0); + // fully connected - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), + USE_BUF(wctx, ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); + USE_BUF(wctx, ctx0, 0); + // GELU activation - cur = ggml_gelu(ctxL, cur); + cur = ggml_gelu(ctx0, cur); + + USE_BUF(wctx, ctx0, 1); // projection - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), - cur); - } - - // output from this layer - struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); - - { - ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); + USE_BUF(wctx, ctx0, 0); - //ggml_graph_print(&gf); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); } - // TODO: this is a hack to have per-layer computation graphs - need to come up with something better - // input for next layer (inpO -> inpL) - memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; - inpL->src0 = nullptr; - inpL->src1 = nullptr; - - if (N > 1) { - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); - } + USE_BUF(wctx, ctx0, 3); - ggml_free(ctxL); + inpL = ggml_add(ctx0, cur, inpFF); } cur = inpL; // norm { + USE_BUF(wctx, ctx0, 0); + cur = ggml_norm(ctx0, cur); + USE_BUF(wctx, ctx0, 1); + cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.d_ln_w, cur), @@ -2063,24 +2159,34 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } + USE_BUF(wctx, ctx0, 0); + + // compute logits only for the last token + // comment this line to compute logits for all N tokens + // might be useful in the future + cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + USE_BUF(wctx, ctx0, -1); + 
// run the computation { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - ggml_build_forward_expand(&gf, logits); ggml_graph_compute (ctx0, &gf); } - logits_out.resize(N*n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + // extract logits for all N tokens + //logits_out.resize(N*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + + // extract logits only for the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); if (N > 1) { - //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; - //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); - //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); + printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, + GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0); } ggml_free(ctx0);
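For readers unfamiliar with ggml scratch buffers, the pattern wrapped by the USE_BUF macro above is plain ggml_set_scratch() ping-pong: a tensor's data is placed in whichever scratch buffer is active when the tensor is created, so alternating between two scratch buffers from one op to the next keeps an op's output out of the buffer its inputs live in, and the main compute buffer only has to hold tensor and graph metadata plus the non-scratch tensors instead of every intermediate activation. The standalone sketch below is illustrative only and not part of the patch: the buffer sizes, tensor shapes and thread count are made up, inputs are left uninitialized, and it assumes only the ggml calls that the patch itself uses (ggml_set_scratch, ggml_mul_mat, ggml_gelu, ggml_graph_compute).

// Illustrative sketch only (not part of the patch): minimal scratch-buffer ping-pong.
// Buffer sizes and shapes are arbitrary; only the memory behaviour is of interest here.
#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // main buffer: tensor headers, graph bookkeeping and non-scratch tensor data
    std::vector<uint8_t> buf_compute(16u*1024*1024);

    // two scratch buffers that the intermediate results alternate between
    std::vector<uint8_t> buf_scratch[2] = {
        std::vector<uint8_t>(16u*1024*1024),
        std::vector<uint8_t>(16u*1024*1024),
    };

    struct ggml_init_params params;
    params.mem_size   = buf_compute.size();
    params.mem_buffer = buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

    // created while no scratch is set -> data lives in buf_compute
    // (left uninitialized here; a real caller would fill these)
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 64, 32);

    // result of the first op is allocated in scratch 0
    ggml_set_scratch(ctx0, { 0, buf_scratch[0].size(), buf_scratch[0].data(), });
    struct ggml_tensor * t0 = ggml_mul_mat(ctx0, w, x);

    // the next op reads from scratch 0 and writes its result into scratch 1
    ggml_set_scratch(ctx0, { 0, buf_scratch[1].size(), buf_scratch[1].data(), });
    struct ggml_tensor * t1 = ggml_gelu(ctx0, t0);

    // final result goes back to the main buffer (scratch disabled), like USE_BUF(..., -1)
    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
    struct ggml_tensor * out = ggml_mul_mat(ctx0, w, t1);

    struct ggml_cgraph gf = {};
    gf.n_threads = 4;

    ggml_build_forward_expand(&gf, out);
    ggml_graph_compute       (ctx0, &gf);

    // only the metadata and the non-scratch tensors count against the main buffer
    printf("used_mem = %zu bytes\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return 0;
}

In the patch itself the encoder and decoder use four such buffers: 0 and 1 alternate for the intermediate tensors inside each layer, 2 receives the residual additions, and 3 holds the per-layer outputs (and, in the decoder, the token plus position embeddings); USE_BUF(wctx, ctx0, -1) switches back to the no-scratch state before the graph is evaluated. The g_buf_max_size high-water marks printed by the new debug lines appear to be what the base-model entries in MEM_REQ_SCRATCH0..3, MEM_REQ_ENCODE and MEM_REQ_DECODE were measured from; the other model sizes still carry placeholder values, hence the "(base model only)" in the subject.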