From 59fdcd19c8b24ec6d0bdfab9847ca66c805ed831 Mon Sep 17 00:00:00 2001 From: sandrohanea <40202887+sandrohanea@users.noreply.github.com> Date: Sun, 5 Mar 2023 20:42:19 +0100 Subject: [PATCH] whisper : add whisper_state + default state on the whisper_context (#523) * Added whisper state + default state on the whisper_context * Fixed some examples and bindings * Fixed whisper_n_len (which was used in some binding) and added whisper_n_len_from_state * Fixed comments * whisper : reuse kv_cache_free() and fix compiler warnings * whisper : clean-up the API comments --------- Co-authored-by: Sandro Hanea Co-authored-by: Georgi Gerganov --- bindings/go/whisper.go | 4 +- bindings/ruby/ext/ruby_whisper.cpp | 2 +- examples/addon.node/addon.cpp | 4 +- examples/main/main.cpp | 4 +- whisper.cpp | 982 +++++++++++++++++------------ whisper.h | 158 +++-- 6 files changed, 701 insertions(+), 453 deletions(-) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 78ca07d..d47f7f7 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data); // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments -static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) { +static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) { if(user_data != NULL && ctx != NULL) { callNewSegment(user_data, n_new); } @@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted -static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) { +static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) { if(user_data != NULL && ctx != NULL) { return callEncoderBegin(user_data); } diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index e7416ba..82027d4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { bool is_aborted = *(bool*)user_data; return !is_aborted; }; diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 2ef895f..8252327 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); } -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -260,7 +260,7 @@ int run(whisper_params ¶ms, std::vector> &result) { { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { bool is_aborted = *(bool*)user_data; return !is_aborted; }; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e1853c6..cd7d928 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -193,7 +193,7 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -608,7 +608,7 @@ int main(int argc, char ** argv) { { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { bool is_aborted = *(bool*)user_data; return !is_aborted; }; diff --git a/whisper.cpp b/whisper.cpp index b4a0407..c8a904b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -547,13 +547,11 @@ struct whisper_decoder { std::vector tokens_tmp; // used for whisper_decode calls }; -struct whisper_context { - int64_t t_load_us = 0; - int64_t t_mel_us = 0; +struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; - int64_t t_start_us = 0; + int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls @@ -561,16 +559,10 @@ struct whisper_context { int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures - ggml_type wtype; // weight type (FP32 or FP16) - - whisper_mel mel; - - whisper_model model; - whisper_vocab vocab; - // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; @@ -635,6 +627,18 @@ struct whisper_context { } }; +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + + ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16) + + whisper_model model; + whisper_vocab vocab; + whisper_state * state = nullptr; +}; + template static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); @@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - - wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); - wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); - wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); - wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); + // we skip initialization of the state until it is needed + // because it might be that state will always be provided externally. } // load mel filters @@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con vocab.id_to_token[i] = word; } } - - wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - - wctx.logits_id.reserve(n_vocab); - - // TAGS: WHISPER_DECODER_INIT - wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); - - wctx.decoders[0].probs.reserve (vocab.n_vocab); - wctx.decoders[0].logits.reserve (vocab.n_vocab); - wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } size_t ctx_size = 0; @@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - wctx.rng = std::mt19937(0); - wctx.t_load_us = ggml_time_us() - t_start_us; return true; } -// evaluate the encoder +// evaluate the encoder with the given state // // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder // part of the transformer model and returns the encoded features // -// - model: the model +// - wctx: the model +// - wstate: the state of the encoder // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // -static bool whisper_encode( +static bool whisper_encode_internal( whisper_context & wctx, + whisper_state & wstate, const int mel_offset, - const int n_threads) { + const int n_threads){ + const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; - const auto & mel_inp = wctx.mel; + const auto & mel_inp = wstate.mel; const auto & hparams = model.hparams; - const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; @@ -1374,12 +1344,12 @@ static bool whisper_encode( assert(mel_inp.n_mel == n_mels); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); @@ -1401,30 +1371,30 @@ static bool whisper_encode( // convolution + gelu { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); cur = ggml_gelu(ctx0, cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1439,7 +1409,7 @@ static bool whisper_encode( //} static int iter = 0; - + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; @@ -1459,54 +1429,54 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } // self-attention { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); + layer.attn_q_w, + cur); Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); + layer.attn_k_w, + cur); //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); + layer.attn_v_w, + cur); Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = @@ -1583,29 +1553,29 @@ static bool whisper_encode( #endif struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); } // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); + layer.attn_ln_1_w, + cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpL); @@ -1616,61 +1586,61 @@ static bool whisper_encode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, layer.mlp_ln_b, cur)); - } + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + } #ifdef WHISPER_USE_FLASH_FF - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), - layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); + layer.mlp_0_w, + cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); + ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); + layer.mlp_1_w, + cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); #endif - } +} - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -1679,21 +1649,21 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.e_ln_w, cur), - cur), - ggml_repeat(ctx0, model.e_ln_b, cur)); + ggml_mul(ctx0, + ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + ggml_repeat(ctx0, model.e_ln_b, cur)); } - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -1701,7 +1671,7 @@ static bool whisper_encode( gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); - ggml_graph_compute (ctx0, &gf); + ggml_graph_compute(ctx0, &gf); //ggml_graph_print(&gf); } @@ -1731,34 +1701,34 @@ static bool whisper_encode( cur->src1 = nullptr; for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto & layer = model.layers_decoder[il]; + auto& layer = model.layers_decoder[il]; - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); - struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); + struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); - Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); - struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); + struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); - //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1779,8 +1749,8 @@ static bool whisper_encode( ggml_free(ctx0); - wctx.t_encode_us += ggml_time_us() - t_start_us; - wctx.n_encode++; + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; return true; } @@ -1795,8 +1765,9 @@ static bool whisper_encode( // - n_tokens: number of tokens in the prompt // - n_past: number of past tokens to prefix the prompt with // -static bool whisper_decode( +static bool whisper_decode_internal( whisper_context & wctx, + whisper_state & wstate, whisper_decoder & decoder, const whisper_token * tokens, const int n_tokens, @@ -1811,7 +1782,7 @@ static bool whisper_decode( WHISPER_ASSERT(!!kv_self.ctx); - auto & logits_out = wctx.logits; + auto & logits_out = wstate.logits; const int n_vocab = hparams.n_vocab; @@ -1821,13 +1792,13 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); @@ -1842,7 +1813,7 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // token encoding + position encoding struct ggml_tensor * cur = @@ -1857,7 +1828,7 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1871,7 +1842,7 @@ static bool whisper_decode( // self-attention { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, @@ -1913,7 +1884,7 @@ static bool whisper_decode( // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1929,12 +1900,12 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); //struct ggml_tensor * KQ_scaled = // ggml_scale(ctx0, @@ -1944,11 +1915,11 @@ static bool whisper_decode( struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * V_trans = ggml_permute(ctx0, @@ -1957,7 +1928,7 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -1970,31 +1941,31 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, @@ -2006,7 +1977,7 @@ static bool whisper_decode( // cross-attention { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, @@ -2023,19 +1994,19 @@ static bool whisper_decode( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); // ------ - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -2046,7 +2017,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2060,15 +2031,15 @@ static bool whisper_decode( // no masking for cross-attention //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -2080,20 +2051,20 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpCA); @@ -2104,11 +2075,11 @@ static bool whisper_decode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -2118,39 +2089,39 @@ static bool whisper_decode( ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -2159,11 +2130,11 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2172,7 +2143,7 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // compute logits only for the last token // comment this line to compute logits for all N tokens @@ -2181,7 +2152,7 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -2208,8 +2179,8 @@ static bool whisper_decode( ggml_free(ctx0); - wctx.t_decode_us += ggml_time_us() - t_start_us; - wctx.n_decode++; + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; return true; } @@ -2313,7 +2284,7 @@ static void fft(const std::vector & in, std::vector & out) { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( - whisper_context & wctx, + whisper_state & wstate, const float * samples, const int n_samples, const int /*sample_rate*/, @@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram( mel.data[i] = (mel.data[i] + 4.0)/4.0; } - wctx.t_mel_us += ggml_time_us() - t_start_us; + wstate.t_mel_us += ggml_time_us() - t_start_us; return true; } @@ -2507,7 +2478,56 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // interface implementation // -struct whisper_context * whisper_init_from_file(const char * path_model) { +struct whisper_state * whisper_init_state(whisper_context * ctx) { + whisper_state * state = new whisper_state; + + const size_t scale = ctx->model.hparams.f16 ? 1 : 2; + + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->logits_id.reserve(ctx->model.hparams.n_vocab); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); + + state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); + state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); + state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); + state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + + state->rng = std::mt19937(0); + + return state; +} + +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { whisper_model_loader loader = {}; fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); @@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) { fin->close(); }; - return whisper_init(&loader); + return whisper_init_no_state(&loader); } -struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { struct buf_context { uint8_t* buffer; size_t size; @@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s loader.close = [](void * /*ctx*/) { }; - return whisper_init(&loader); + return whisper_init_no_state(&loader); } -struct whisper_context * whisper_init(struct whisper_model_loader * loader) { +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { ggml_time_init(); whisper_context * ctx = new whisper_context; @@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return ctx; } +struct whisper_context * whisper_init_from_file(const char * path_model) { + whisper_context * ctx = whisper_init_from_file_no_state(path_model); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + whisper_context * ctx = whisper_init_no_state(loader); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +void whisper_free_state(struct whisper_state * state) +{ + if (state) { + kv_cache_free(state->kv_cross); + + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + kv_cache_free(state->decoders[i].kv_self); + } + + delete state; + } +} + void whisper_free(struct whisper_context * ctx) { if (ctx) { if (ctx->model.ctx) { @@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.buf) { delete ctx->model.buf; } - if (ctx->kv_cross.ctx) { - ggml_free(ctx->kv_cross.ctx); - } - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - if (ctx->decoders[i].kv_self.ctx) { - ggml_free(ctx->decoders[i].kv_self.ctx); - } - } + + whisper_free_state(ctx->state); + delete ctx; } } +int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int whisper_set_mel_with_state( + struct whisper_context * /*ctx*/, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != WHISPER_N_MEL) { + fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); return -1; } + state->mel.n_len = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + return 0; } @@ -2635,22 +2737,20 @@ int whisper_set_mel( const float * data, int n_len, int n_mel) { - if (n_mel != WHISPER_N_MEL) { - fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } - ctx->mel.n_len = n_len; - ctx->mel.n_mel = n_mel; - - ctx->mel.data.resize(n_len*n_mel); - memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); - return 0; } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode(*ctx, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } @@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { return 0; } +int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + const int selected_decoder_id = 0; + + if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - // TODO: add selected_decoder_id to context + // TODO: add selected_decoder_id to state const int selected_decoder_id = 0; - if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (ctx->state == nullptr) { + fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__); + return false; + } + + + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) { return nullptr; } -int whisper_lang_auto_detect( +int whisper_lang_auto_detect_with_state( struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs) { + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs) { const int seek = offset_ms/10; if (seek < 0) { @@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect( return -1; } - if (seek >= ctx->mel.n_len) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); + if (seek >= state->mel.n_len) { + fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10); return -2; } @@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect( const std::vector prompt = { whisper_token_sot(ctx) }; - if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } - auto & logits_id = ctx->logits_id; + auto & logits_id = state->logits_id; logits_id.clear(); for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); } // sort descending @@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect( return logits_id[0].second; } +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); +} + +int whisper_n_len_from_state(struct whisper_state * state) { + return state->mel.n_len; +} + int whisper_n_len(struct whisper_context * ctx) { - return ctx->mel.n_len; + return ctx->state->mel.n_len; } int whisper_n_vocab(struct whisper_context * ctx) { @@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) { } float * whisper_get_logits(struct whisper_context * ctx) { - return ctx->logits.data(); + return ctx->state->logits.data(); +} + + +float * whisper_get_logits_from_state(struct whisper_state * state) { + return state->logits.data(); } const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { @@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_encode = std::max(1, ctx->n_encode); - const int32_t n_decode = std::max(1, ctx->n_decode); - fprintf(stderr, "\n"); - fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); - fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); - fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + } fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { - ctx->t_sample_us = 0; - ctx->t_encode_us = 0; - ctx->t_decode_us = 0; + if (ctx->state != nullptr) { + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + } } const char * whisper_print_system_info(void) { @@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( struct whisper_context & ctx, + struct whisper_state & state, int i_segment, float thold_pt, float thold_ptsum); @@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) { // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { - auto segment = ctx.result_all.back(); +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { + auto segment = state.result_all.back(); int res = 1; int acc = 0; @@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool trim(text); } - ctx.result_all.back().text = std::move(text); - ctx.result_all.back().t1 = token.t0; - ctx.result_all.back().tokens.resize(i); + state.result_all.back().text = std::move(text); + state.result_all.back().t1 = token.t0; + state.result_all.back().tokens.resize(i); - ctx.result_all.push_back({}); - ctx.result_all.back().t0 = token.t0; - ctx.result_all.back().t1 = segment.t1; + state.result_all.push_back({}); + state.result_all.back().t0 = token.t0; + state.result_all.back().t1 = segment.t1; // add tokens [i, end] to the new segment - ctx.result_all.back().tokens.insert( - ctx.result_all.back().tokens.end(), + state.result_all.back().tokens.insert( + state.result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end()); acc = 0; text = ""; - segment = ctx.result_all.back(); + segment = state.result_all.back(); i = -1; res++; @@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool if (split_on_word) { trim(text); } - ctx.result_all.back().text = std::move(text); + state.result_all.back().text = std::move(text); return res; } @@ -3093,6 +3234,7 @@ static const std::vector non_speech_tokens = { // - computes logprobs and probs static void whisper_process_logits( struct whisper_context & ctx, + struct whisper_state & state, const struct whisper_full_params params, struct whisper_decoder & decoder, float temperature) { @@ -3111,7 +3253,7 @@ static void whisper_process_logits( auto & logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); if (temperature > 0.0f) { for (int i = 0; i < n_logits; i++) { @@ -3149,7 +3291,7 @@ static void whisper_process_logits( logits[vocab.token_transcribe] = -INFINITY; if (params.logits_filter_callback) { - params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } // suppress non-speech tokens @@ -3310,6 +3452,7 @@ static void whisper_process_logits( static whisper_token_data whisper_sample_token( whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token( } else { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(ctx.rng); + result.id = dist(state.rng); result.p = probs[result.id]; result.plog = logprobs[result.id]; } @@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } - ctx.n_sample++; + state.n_sample++; return result; } static std::vector whisper_sample_token_topk( whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, int k) { const auto & vocab = ctx.vocab; @@ -3381,7 +3525,7 @@ static std::vector whisper_sample_token_topk( const int n_logits = vocab.n_vocab; - auto & logits_id = ctx.logits_id; + auto & logits_id = state.logits_id; logits_id.clear(); for (int i = 0; i < n_logits; ++i) { @@ -3434,7 +3578,7 @@ static std::vector whisper_sample_token_topk( } } - ctx.n_sample++; + state.n_sample++; return result; } @@ -3488,24 +3632,25 @@ static void whisper_sequence_score( } } -int whisper_full( +int whisper_full_with_state( struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) { // clear old results - auto & result_all = ctx->result_all; + auto & result_all = state->result_all; result_all.clear(); // compute log mel spectrogram if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -1; } } else { - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -2; } @@ -3515,26 +3660,26 @@ int whisper_full( if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { std::vector probs(whisper_lang_max_id() + 1, 0.0f); - const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - ctx->lang_id = lang_id; + state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; - ctx->tid_last = 0; - ctx->energy = get_signal_energy(samples, n_samples, 32); + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + state->energy = get_signal_energy(samples, n_samples, 32); } const int seek_start = params.offset_ms/10; - const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10); // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s @@ -3572,10 +3717,10 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 1; j < n_decoders; j++) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = ctx->decoders[0].kv_self; + decoder.kv_self = state->decoders[0].kv_self; if (!kv_cache_reinit(decoder.kv_self)) { fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; @@ -3583,7 +3728,7 @@ int whisper_full( WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); @@ -3592,7 +3737,7 @@ int whisper_full( } // the accumulated text context so far - auto & prompt_past = ctx->prompt_past; + auto & prompt_past = state->prompt_past; if (params.no_context) { prompt_past.clear(); } @@ -3611,13 +3756,13 @@ int whisper_full( fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } - ctx->exp_n_audio_ctx = params.audio_ctx; + state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); - ctx->lang_id = lang_id; + state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -3669,14 +3814,14 @@ int whisper_full( } if (params.encoder_begin_callback) { - if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); break; } } // encode audio features starting at offset seek - if (!whisper_encode(*ctx, seek, params.n_threads)) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } @@ -3717,7 +3862,7 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; decoder.kv_self.n = 0; @@ -3759,7 +3904,7 @@ int whisper_full( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } @@ -3767,24 +3912,24 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); - ctx->decoders[0].kv_self.n += prompt.size(); + state->decoders[0].kv_self.n += prompt.size(); for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; - memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); decoder.kv_self.n += prompt.size(); - memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } @@ -3795,7 +3940,7 @@ int whisper_full( if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { kv_bufs.resize(n_decoders_cur); for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3813,7 +3958,7 @@ int whisper_full( // generate new sequence candidates for each decoder for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3823,16 +3968,16 @@ int whisper_full( case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true)); } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false)); } decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); for (const auto & token : tokens_new) { beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); @@ -3857,7 +4002,7 @@ int whisper_full( uint32_t cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3886,7 +4031,7 @@ int whisper_full( // - check if the sequence is failed // - update sliding window based on timestamp tokens for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3968,7 +4113,7 @@ int whisper_full( bool completed_all = true; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3982,11 +4127,11 @@ int whisper_full( } } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; // obtain logits for the next token for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed || decoder.completed) { continue; @@ -3997,7 +4142,7 @@ int whisper_full( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -8; } @@ -4005,11 +4150,11 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, decoder, t_cur); + whisper_process_logits(*ctx, *state, params, decoder, t_cur); ++decoder.kv_self.n; - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } } @@ -4019,7 +4164,7 @@ int whisper_full( double best_score = -INFINITY; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed) { continue; @@ -4036,7 +4181,7 @@ int whisper_full( __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; - ctx->n_fail_h++; + state->n_fail_h++; continue; } @@ -4054,11 +4199,11 @@ int whisper_full( { bool success = true; - const auto & decoder = ctx->decoders[best_decoder_id]; + const auto & decoder = state->decoders[best_decoder_id]; if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { success = false; - ctx->n_fail_p++; + state->n_fail_p++; } if (success) { @@ -4075,7 +4220,7 @@ int whisper_full( // output results through a user-provided callback { - const auto & best_decoder = ctx->decoders[best_decoder_id]; + const auto & best_decoder = state->decoders[best_decoder_id]; const auto seek_delta = best_decoder.seek_delta; const auto result_len = best_decoder.sequence.result_len; @@ -4138,14 +4283,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -4182,14 +4327,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } } @@ -4204,6 +4349,15 @@ int whisper_full( return 0; } + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + int whisper_full_parallel( struct whisper_context * ctx, struct whisper_full_params params, @@ -4213,40 +4367,10 @@ int whisper_full_parallel( if (n_processors == 1) { return whisper_full(ctx, params, samples, n_samples); } - int ret = 0; - // prepare separate contexts for each thread - std::vector ctxs(n_processors - 1); - - for (int i = 0; i < n_processors - 1; ++i) { - auto & ctx_p = ctxs[i]; - - ctx_p = *ctx; - - ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - - ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); - - if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); - return false; - } - - // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); - return false; - } - - ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); - - ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); - } - } + // prepare separate states for each thread + std::vector states; const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; @@ -4256,6 +4380,9 @@ int whisper_full_parallel( std::vector workers(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; @@ -4268,13 +4395,17 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; - workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); } { auto params_cur = params; - ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); } for (int i = 0; i < n_processors - 1; ++i) { @@ -4283,45 +4414,43 @@ int whisper_full_parallel( const int64_t offset_t = (int64_t) params.offset_ms/10.0; - // combine results into ctx->result_all + // combine results into result_state->result_all from all other states for (int i = 0; i < n_processors - 1; ++i) { - auto & results_i = ctxs[i].result_all; + auto& results_i = states[i]->result_all; - for (auto & result : results_i) { + for (auto& result : results_i) { // correct the segment timestamp taking into account the offset - result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; - result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + // make sure that segments are not overlapping - if (!ctx->result_all.empty()) { - result.t0 = std::max(result.t0, ctx->result_all.back().t1); + if (!ctx->state->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); } - ctx->result_all.push_back(std::move(result)); + ctx->state->result_all.push_back(std::move(result)); // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); } } - ctx->t_mel_us += ctxs[i].t_mel_us; - ctx->t_sample_us += ctxs[i].t_sample_us; - ctx->t_encode_us += ctxs[i].t_encode_us; - ctx->t_decode_us += ctxs[i].t_decode_us; + ctx->state->t_mel_us += states[i]->t_mel_us; - kv_cache_free(ctx->kv_cross); + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - kv_cache_free(ctx->decoders[j].kv_self); - } + whisper_free_state(states[i]); } // average the timings - ctx->t_mel_us /= n_processors; - ctx->t_sample_us /= n_processors; - ctx->t_encode_us /= n_processors; - ctx->t_decode_us /= n_processors; + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; // print information about the audio boundaries fprintf(stderr, "\n"); @@ -4334,44 +4463,84 @@ int whisper_full_parallel( return ret; } +int whisper_full_n_segments_from_state(struct whisper_state * state) { + return state->result_all.size(); +} + int whisper_full_n_segments(struct whisper_context * ctx) { - return ctx->result_all.size(); + return ctx->state->result_all.size(); +} + +int whisper_full_lang_id_from_state(struct whisper_state * state) { + return state->lang_id; } int whisper_full_lang_id(struct whisper_context * ctx) { - return ctx->lang_id; + return ctx->state->lang_id; +} + +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t0; } int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t0; + return ctx->state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t1; } int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t1; + return ctx->state->result_all[i_segment].t1; +} + +const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); } const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].text.c_str(); + return ctx->state->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); } int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].tokens.size(); + return ctx->state->result_all[i_segment].tokens.size(); } -const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; } whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].id; + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; } struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token]; + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; } float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].p; + return ctx->state->result_all[i_segment].tokens[i_token].p; } // ================================================================================================= @@ -4583,13 +4752,14 @@ static std::vector get_signal_energy(const float * signal, int n_samples, static void whisper_exp_compute_token_level_timestamps( struct whisper_context & ctx, + struct whisper_state & state, int i_segment, float thold_pt, float thold_ptsum) { - auto & segment = ctx.result_all[i_segment]; + auto & segment = state.result_all[i_segment]; auto & tokens = segment.tokens; - const int n_samples = ctx.energy.size(); + const int n_samples = state.energy.size(); if (n_samples == 0) { fprintf(stderr, "%s: no signal data available\n", __func__); @@ -4612,9 +4782,9 @@ static void whisper_exp_compute_token_level_timestamps( return; } - auto & t_beg = ctx.t_beg; - auto & t_last = ctx.t_last; - auto & tid_last = ctx.tid_last; + auto & t_beg = state.t_beg; + auto & t_last = state.t_last; + auto & tid_last = state.tid_last; for (int j = 0; j < n; ++j) { auto & token = tokens[j]; @@ -4737,15 +4907,15 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; for (int k = ss0; k < ss1; k++) { - sum += ctx.energy[k]; + sum += state.energy[k]; } const float thold = 0.5*sum/ns; { int k = s0; - if (ctx.energy[k] > thold && j > 0) { - while (k > 0 && ctx.energy[k] > thold) { + if (state.energy[k] > thold && j > 0) { + while (k > 0 && state.energy[k] > thold) { k--; } tokens[j].t0 = sample_to_timestamp(k); @@ -4755,7 +4925,7 @@ static void whisper_exp_compute_token_level_timestamps( s0 = k; } } else { - while (ctx.energy[k] < thold && k < s1) { + while (state.energy[k] < thold && k < s1) { k++; } s0 = k; @@ -4765,8 +4935,8 @@ static void whisper_exp_compute_token_level_timestamps( { int k = s1; - if (ctx.energy[k] > thold) { - while (k < n_samples - 1 && ctx.energy[k] > thold) { + if (state.energy[k] > thold) { + while (k < n_samples - 1 && state.energy[k] > thold) { k++; } tokens[j].t1 = sample_to_timestamp(k); @@ -4776,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps( s1 = k; } } else { - while (ctx.energy[k] < thold && k > s0) { + while (state.energy[k] < thold && k > s0) { k--; } s1 = k; diff --git a/whisper.h b/whisper.h index 3eb8d08..3984195 100644 --- a/whisper.h +++ b/whisper.h @@ -66,6 +66,7 @@ extern "C" { // struct whisper_context; + struct whisper_state; typedef int whisper_token; @@ -101,11 +102,20 @@ extern "C" { WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); - // Frees all memory allocated by the model. - WHISPER_API void whisper_free(struct whisper_context * ctx); + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model); + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size); + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader); + + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + + // Frees all allocated memory + WHISPER_API void whisper_free (struct whisper_context * ctx); + WHISPER_API void whisper_free_state(struct whisper_state * state); // Convert RAW PCM audio to log mel spectrogram. - // The resulting spectrogram is stored inside the provided whisper context. + // The resulting spectrogram is stored inside the default state of the provided whisper context. // Returns 0 on success WHISPER_API int whisper_pcm_to_mel( struct whisper_context * ctx, @@ -113,17 +123,30 @@ extern "C" { int n_samples, int n_threads); - // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - // The resulting spectrogram is stored inside the provided whisper context. + WHISPER_API int whisper_pcm_to_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the default state of the provided whisper context. // Returns 0 on success WHISPER_API int whisper_pcm_to_mel_phase_vocoder( - struct whisper_context* ctx, - const float* samples, - int n_samples, - int n_threads); - - - // This can be used to set a custom log mel spectrogram inside the provided whisper context. + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 // Returns 0 on success @@ -133,7 +156,14 @@ extern "C" { int n_len, int n_mel); - // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. + WHISPER_API int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. // Returns 0 on success @@ -142,6 +172,12 @@ extern "C" { int offset, int n_threads); + WHISPER_API int whisper_encode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset, + int n_threads); + // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. // tokens + n_tokens is the provided context for the decoder. @@ -155,6 +191,14 @@ extern "C" { int n_past, int n_threads); + WHISPER_API int whisper_decode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens @@ -190,17 +234,26 @@ extern "C" { int n_threads, float * lang_probs); - WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length - WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); + WHISPER_API int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); // Token logits obtained from the last call to whisper_decode() // The logits for the last token are stored in the last row // Rows: n_tokens // Cols: n_vocab - WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state); // Token Id -> String. Uses the vocabulary in the provided context WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); @@ -218,7 +271,7 @@ extern "C" { WHISPER_API whisper_token whisper_token_translate (void); WHISPER_API whisper_token whisper_token_transcribe(void); - // Performance information + // Performance information from the default state. WHISPER_API void whisper_print_timings(struct whisper_context * ctx); WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); @@ -236,18 +289,19 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted - typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits typedef void (*whisper_logits_filter_callback)( struct whisper_context * ctx, + struct whisper_state * state, const whisper_token_data * tokens, int n_tokens, float * logits, @@ -334,6 +388,7 @@ extern "C" { WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context // Uses the specified decoding strategy to obtain the text. WHISPER_API int whisper_full( struct whisper_context * ctx, @@ -341,7 +396,16 @@ extern "C" { const float * samples, int n_samples); - // Split the input audio in chunks and process each chunk separately using whisper_full() + WHISPER_API int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. WHISPER_API int whisper_full_parallel( @@ -351,33 +415,47 @@ extern "C" { int n_samples, int n_processors); - // Number of generated text segments. + // Number of generated text segments // A segment can be a few words, a sentence, or even a paragraph. - WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state); - // Language id associated with the current context + // Language id associated with the context's default state WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); - // Get the start and end time of the specified segment. - WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + // Language id associated with the provided state + WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + + // Get the start and end time of the specified segment + WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); + + WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + + // Get the text of the specified segment + WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); + WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); - // Get the text of the specified segment. - WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + // Get number of tokens in the specified segment + WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment); + WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment); - // Get number of tokens in the specified segment. - WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + // Get the token text of the specified token in the specified segment + WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token); - // Get the token text of the specified token in the specified segment. - WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token); - // Get token data for the specified token in the specified segment. + // Get token data for the specified token in the specified segment // This contains probabilities, timestamps, etc. - WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token); - // Get the probability of the specified token in the specified segment. - WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + // Get the probability of the specified token in the specified segment + WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); ////////////////////////////////////////////////////////////////////////////