From 1163f266d225cd020568679d1e2c3decadb1e35f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 9 Jan 2023 19:02:13 +0200
Subject: [PATCH] whisper : move probs_id buffer to whisper_context

---
 whisper.cpp | 46 ++++++++++++++++++++++++----------------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index d57dd9c..5b7b3fe 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -220,10 +220,6 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
-    // used to avoid memory allocations during sampling
-    // TODO: move to whisper_context in the future
-    std::vector<std::pair<double, id>> probs_id;
-
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -387,14 +383,14 @@ struct whisper_model {
     struct ggml_tensor * e_ln_b;
 
     // decoder.positional_embedding
-    struct ggml_tensor * d_pe; // DD
+    struct ggml_tensor * d_pe;
 
     // decoder.token_embedding
-    struct ggml_tensor * d_te; // DD
+    struct ggml_tensor * d_te;
 
     // decoder.ln
-    struct ggml_tensor * d_ln_w; // DD
-    struct ggml_tensor * d_ln_b; // DD
+    struct ggml_tensor * d_ln_w;
+    struct ggml_tensor * d_ln_b;
 
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
@@ -451,6 +447,8 @@ struct whisper_context {
     std::vector<float> work_logits;
     std::vector<float> work_logprobs;
 
+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
     int64_t t_last;
@@ -545,10 +543,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: type = %d\n", __func__, model.type);
 
         wctx.buf_model = new std::vector<uint8_t>();
-        wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
-        wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
-        wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
-        wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+        wctx.buf_model->resize       (MEM_REQ_MODEL.at(model.type));
+        wctx.buf_memory.resize       (MEM_REQ_MEMORY.at(model.type));
+        wctx.buf_memory_cross.resize (MEM_REQ_MEMORY_CROSS.at(model.type));
+        wctx.buf_compute.resize      (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
 
@@ -638,7 +636,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.work_logits.reserve(vocab.n_vocab);
         wctx.work_logprobs.reserve(vocab.n_vocab);
 
-        vocab.probs_id.reserve(n_vocab);
+        wctx.probs_id.reserve(n_vocab);
     }
 
     {
@@ -1008,7 +1006,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); 
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
         fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
@@ -1900,17 +1898,19 @@ static bool whisper_decode(
 
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        whisper_vocab & vocab,
+      whisper_context & ctx,
           const float * probs,
-          bool force_timestamp,
-          bool is_initial) {
+                   bool force_timestamp,
+                   bool is_initial) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
 
+    const auto & vocab = ctx.vocab;
+
     const int n_logits = vocab.n_vocab;
 
-    auto & probs_id = vocab.probs_id;
+    auto & probs_id = ctx.probs_id;
     probs_id.clear();
 
     for (int i = 0; i < n_logits; i++) {
@@ -2461,7 +2461,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
 struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
+    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
@@ -2471,7 +2471,7 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
 struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
+    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
@@ -2558,7 +2558,9 @@ int whisper_lang_auto_detect(
         return -7;
     }
 
-    std::vector<std::pair<float, int>> probs_id;
+    auto & probs_id = ctx->probs_id;
+    probs_id.clear();
+
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
         probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
@@ -2566,7 +2568,7 @@ int whisper_lang_auto_detect(
     }
     // sort descending
     {
-        using pair_type = decltype(probs_id)::value_type;
+        using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;
         std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
             return a.first > b.first;
         });
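
Note (not part of the patch): the `using pair_type` change in the last hunk is needed because `probs_id` is now a reference (`auto & probs_id = ctx->probs_id;`) rather than a local vector. `decltype` applied to a reference variable yields a reference type, and a reference type cannot appear to the left of `::value_type`, so the reference has to be stripped first. A minimal standalone sketch of the idiom follows; the variable `buf` is a hypothetical stand-in for `ctx->probs_id`, and this is only an illustration, not code from the patch:

    #include <type_traits>
    #include <utility>
    #include <vector>

    int main() {
        std::vector<std::pair<double, int>> buf; // stand-in for ctx->probs_id
        auto & probs_id = buf;                   // a reference, as in the patch

        // decltype(probs_id) is std::vector<std::pair<double, int>> &.
        // A reference type cannot be used as a nested-name-specifier, so
        // "decltype(probs_id)::value_type" would not compile here; stripping
        // the reference first restores access to the element type:
        using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;

        static_assert(std::is_same<pair_type, std::pair<double, int>>::value,
                      "pair_type is the vector's element type");
        return 0;
    }

C++14 and later could shorten this to `std::remove_reference_t<decltype(probs_id)>::value_type`; the patch presumably spells out `::type` because the codebase targets C++11.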