diff --git a/whisper.cpp b/whisper.cpp index 24e16bd..0a54654 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -296,14 +296,14 @@ struct whisper_mel { int n_len; int n_mel; - std::vector data; + std::vector data{}; }; struct whisper_filters { int32_t n_mel; int32_t n_fft; - std::vector data; + std::vector data{}; }; struct whisper_vocab { @@ -540,20 +540,18 @@ struct whisper_decoder { bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) - std::vector probs; - std::vector logits; - std::vector logprobs; + std::vector probs{}; + std::vector logits{}; + std::vector logprobs{}; - std::vector tokens_tmp; // used for whisper_decode calls + std::vector tokens_tmp{}; // used for whisper_decode calls }; -struct whisper_context { - int64_t t_load_us = 0; - int64_t t_mel_us = 0; +struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; - int64_t t_start_us = 0; + int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls @@ -561,16 +559,10 @@ struct whisper_context { int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures - ggml_type wtype; // weight type (FP32 or FP16) - - whisper_mel mel; - - whisper_model model; - whisper_vocab vocab; - // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; @@ -635,6 +627,17 @@ struct whisper_context { } }; +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + + ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16) + + whisper_model model; + whisper_vocab vocab; +}; + template static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); @@ -820,33 +823,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - - wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); - wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); - wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); - wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); } // load mel filters @@ -929,17 +905,6 @@ static bool whisper_model_load(struct 
whisper_model_loader * loader, whisper_con vocab.id_to_token[i] = word; } } - - wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - - wctx.logits_id.reserve(n_vocab); - - // TAGS: WHISPER_DECODER_INIT - wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); - - wctx.decoders[0].probs.reserve (vocab.n_vocab); - wctx.decoders[0].logits.reserve (vocab.n_vocab); - wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } size_t ctx_size = 0; @@ -1339,8 +1304,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - wctx.rng = std::mt19937(0); - wctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -1357,15 +1320,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // static bool whisper_encode( whisper_context & wctx, + whisper_state & wstate, const int mel_offset, const int n_threads) { const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; - const auto & mel_inp = wctx.mel; + const auto & mel_inp = wstate.mel; const auto & hparams = model.hparams; - const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; @@ -1374,12 +1338,12 @@ static bool whisper_encode( assert(mel_inp.n_mel == n_mels); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); @@ -1401,7 +1365,7 @@ static bool whisper_encode( // convolution + gelu { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, @@ -1412,7 +1376,7 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, @@ -1424,7 +1388,7 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1459,7 +1423,7 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1473,7 +1437,7 @@ static bool whisper_encode( // self-attention { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, @@ -1506,7 +1470,7 @@ static bool whisper_encode( // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = @@ -1583,7 +1547,7 @@ static bool whisper_encode( #endif struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_cpy(ctx0, KQV_merged, @@ -1592,20 +1556,20 @@ static bool whisper_encode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, 
ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpL); @@ -1616,11 +1580,11 @@ static bool whisper_encode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -1631,38 +1595,38 @@ static bool whisper_encode( } #ifdef WHISPER_USE_FLASH_FF - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), @@ -1670,7 +1634,7 @@ static bool whisper_encode( #endif } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -1679,11 +1643,11 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, @@ -1693,7 +1657,7 @@ static bool whisper_encode( ggml_repeat(ctx0, model.e_ln_b, cur)); } - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -1733,7 +1697,7 @@ static bool whisper_encode( for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, @@ -1741,7 +1705,7 @@ static bool whisper_encode( Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, @@ -1753,12 +1717,12 @@ static bool whisper_encode( Vcross), Vcross); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); - //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, 
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1779,8 +1743,8 @@ static bool whisper_encode( ggml_free(ctx0); - wctx.t_encode_us += ggml_time_us() - t_start_us; - wctx.n_encode++; + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; return true; } @@ -1797,6 +1761,7 @@ static bool whisper_encode( // static bool whisper_decode( whisper_context & wctx, + whisper_state & wstate, whisper_decoder & decoder, const whisper_token * tokens, const int n_tokens, @@ -1811,7 +1776,7 @@ static bool whisper_decode( WHISPER_ASSERT(!!kv_self.ctx); - auto & logits_out = wctx.logits; + auto & logits_out = wstate.logits; const int n_vocab = hparams.n_vocab; @@ -1821,13 +1786,13 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); @@ -1842,7 +1807,7 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // token encoding + position encoding struct ggml_tensor * cur = @@ -1857,7 +1822,7 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1871,7 +1836,7 @@ static bool whisper_decode( // self-attention { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, @@ -1913,7 +1878,7 @@ static bool whisper_decode( // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1929,12 +1894,12 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); //struct ggml_tensor * KQ_scaled = // ggml_scale(ctx0, @@ -1944,11 +1909,11 @@ static bool whisper_decode( struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * V_trans = ggml_permute(ctx0, @@ -1957,7 +1922,7 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -1970,31 +1935,31 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the 
input struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, @@ -2006,7 +1971,7 @@ static bool whisper_decode( // cross-attention { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, @@ -2023,19 +1988,19 @@ static bool whisper_decode( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); // ------ - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -2046,7 +2011,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2060,15 +2025,15 @@ static bool whisper_decode( // no masking for cross-attention //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -2080,20 +2045,20 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpCA); @@ -2104,11 +2069,11 @@ static bool whisper_decode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -2118,39 +2083,39 @@ static bool whisper_decode( ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -2159,11 +2124,11 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + 
wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2172,7 +2137,7 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // compute logits only for the last token // comment this line to compute logits for all N tokens @@ -2181,7 +2146,7 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -2208,8 +2173,8 @@ static bool whisper_decode( ggml_free(ctx0); - wctx.t_decode_us += ggml_time_us() - t_start_us; - wctx.n_decode++; + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; return true; } @@ -2314,6 +2279,7 @@ static void fft(const std::vector & in, std::vector & out) { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( whisper_context & wctx, + whisper_state & wstate, const float * samples, const int n_samples, const int /*sample_rate*/, @@ -2433,7 +2399,7 @@ static bool log_mel_spectrogram( mel.data[i] = (mel.data[i] + 4.0)/4.0; } - wctx.t_mel_us += ggml_time_us() - t_start_us; + wstate.t_mel_us += ggml_time_us() - t_start_us; return true; } @@ -2507,6 +2473,54 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // interface implementation // +struct whisper_state* whisper_init_state(whisper_context * ctx) { + whisper_state * state = new whisper_state; + + const size_t scale = ctx->model.hparams.f16 ? 1 : 2; + + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->logits_id.reserve(ctx->model.hparams.n_vocab); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); + + state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); + state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); + state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); + state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + + state->rng = std::mt19937(0); + return state; +} + struct whisper_context * whisper_init_from_file(const char * path_model) { whisper_model_loader loader = {}; @@ 
-2591,6 +2605,21 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return ctx; } +void whisper_free_state(struct whisper_state* state) +{ + if (state) { + if (state->kv_cross.ctx) { + ggml_free(state->kv_cross.ctx); + } + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + if (state->decoders[i].kv_self.ctx) { + ggml_free(state->decoders[i].kv_self.ctx); + } + } + delete state; + } +} + void whisper_free(struct whisper_context * ctx) { if (ctx) { if (ctx->model.ctx) { @@ -2599,20 +2628,12 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.buf) { delete ctx->model.buf; } - if (ctx->kv_cross.ctx) { - ggml_free(ctx->kv_cross.ctx); - } - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - if (ctx->decoders[i].kv_self.ctx) { - ggml_free(ctx->decoders[i].kv_self.ctx); - } - } delete ctx; } } -int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { +int whisper_pcm_to_mel(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*ctx, *state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2621,8 +2642,8 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int } // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 -int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*ctx, *state, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2631,26 +2652,26 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * } int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, + struct whisper_state * state, + const float * data, + int n_len, int n_mel) { if (n_mel != WHISPER_N_MEL) { fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); return -1; } - ctx->mel.n_len = n_len; - ctx->mel.n_mel = n_mel; + state->mel.n_len = n_len; + state->mel.n_mel = n_mel; - ctx->mel.data.resize(n_len*n_mel); - memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); return 0; } -int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode(*ctx, offset, n_threads)) { +int whisper_encode(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { + if (!whisper_encode(*ctx, *state, offset, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return -1; 
} @@ -2658,11 +2679,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { return 0; } -int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { +int whisper_decode(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { // TODO: add selected_decoder_id to context const int selected_decoder_id = 0; - if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (!whisper_decode(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -2722,9 +2743,10 @@ const char * whisper_lang_str(int id) { int whisper_lang_auto_detect( struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs) { + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs) { const int seek = offset_ms/10; if (seek < 0) { @@ -2732,30 +2754,30 @@ int whisper_lang_auto_detect( return -1; } - if (seek >= ctx->mel.n_len) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); + if (seek >= state->mel.n_len) { + fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10); return -2; } // run the encoder - if (whisper_encode(ctx, seek, n_threads) != 0) { + if (whisper_encode(ctx, state, seek, n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } const std::vector prompt = { whisper_token_sot(ctx) }; - if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { + if (whisper_decode(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } - auto & logits_id = ctx->logits_id; + auto & logits_id = state->logits_id; logits_id.clear(); for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); } // sort descending @@ -2794,8 +2816,8 @@ int whisper_lang_auto_detect( return logits_id[0].second; } -int whisper_n_len(struct whisper_context * ctx) { - return ctx->mel.n_len; +int whisper_n_len(struct whisper_state * state) { + return state->mel.n_len; } int whisper_n_vocab(struct whisper_context * ctx) { @@ -2814,8 +2836,8 @@ int whisper_is_multilingual(struct whisper_context * ctx) { return ctx->vocab.is_multilingual() ? 
1 : 0; } -float * whisper_get_logits(struct whisper_context * ctx) { - return ctx->logits.data(); +float * whisper_get_logits(struct whisper_state * state) { + return state->logits.data(); } const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { @@ -2858,27 +2880,27 @@ whisper_token whisper_token_transcribe(void) { return whisper_vocab::token_transcribe; } -void whisper_print_timings(struct whisper_context * ctx) { +void whisper_print_timings(struct whisper_context * ctx, struct whisper_state * state) { const int64_t t_end_us = ggml_time_us(); - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_encode = std::max(1, ctx->n_encode); - const int32_t n_decode = std::max(1, ctx->n_decode); + const int32_t n_sample = std::max(1, state->n_sample); + const int32_t n_encode = std::max(1, state->n_encode); + const int32_t n_decode = std::max(1, state->n_decode); fprintf(stderr, "\n"); - fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, state->n_fail_p, state->n_fail_h); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); - fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); - fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, state->t_mel_us/1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f* state->t_sample_us, n_sample, 1e-3f* state->t_sample_us/n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f* state->t_encode_us, n_encode, 1e-3f* state->t_encode_us/n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f* state->t_decode_us, n_decode, 1e-3f* state->t_decode_us/n_decode); fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } -void whisper_reset_timings(struct whisper_context * ctx) { - ctx->t_sample_us = 0; - ctx->t_encode_us = 0; - ctx->t_decode_us = 0; +void whisper_reset_timings(struct whisper_state * state) { + state->t_sample_us = 0; + state->t_encode_us = 0; + state->t_decode_us = 0; } const char * whisper_print_system_info(void) { @@ -2988,6 +3010,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( struct whisper_context & ctx, + struct whisper_state & state, int i_segment, float thold_pt, float thold_ptsum); @@ -3020,8 +3043,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) { // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { - auto segment = ctx.result_all.back(); +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & 
state, int max_len, bool split_on_word) { + auto segment = state.result_all.back(); int res = 1; int acc = 0; @@ -3043,24 +3066,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool trim(text); } - ctx.result_all.back().text = std::move(text); - ctx.result_all.back().t1 = token.t0; - ctx.result_all.back().tokens.resize(i); + state.result_all.back().text = std::move(text); + state.result_all.back().t1 = token.t0; + state.result_all.back().tokens.resize(i); - ctx.result_all.push_back({}); - ctx.result_all.back().t0 = token.t0; - ctx.result_all.back().t1 = segment.t1; + state.result_all.push_back({}); + state.result_all.back().t0 = token.t0; + state.result_all.back().t1 = segment.t1; // add tokens [i, end] to the new segment - ctx.result_all.back().tokens.insert( - ctx.result_all.back().tokens.end(), + state.result_all.back().tokens.insert( + state.result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end()); acc = 0; text = ""; - segment = ctx.result_all.back(); + segment = state.result_all.back(); i = -1; res++; @@ -3073,7 +3096,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool if (split_on_word) { trim(text); } - ctx.result_all.back().text = std::move(text); + state.result_all.back().text = std::move(text); return res; } @@ -3091,6 +3114,7 @@ static const std::vector non_speech_tokens // - computes logprobs and probs static void whisper_process_logits( const struct whisper_context & ctx, + const struct whisper_state & state, const struct whisper_full_params params, struct whisper_decoder & decoder, float temperature) { @@ -3109,7 +3133,7 @@ static void whisper_process_logits( auto & logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); if (temperature > 0.0f) { for (int i = 0; i < n_logits; i++) { @@ -3310,6 +3334,7 @@ static void whisper_process_logits( static whisper_token_data whisper_sample_token( whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -3354,7 +3379,7 @@ static whisper_token_data whisper_sample_token( } else { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(ctx.rng); + result.id = dist(state.rng); result.p = probs[result.id]; result.plog = logprobs[result.id]; } @@ -3364,13 +3389,14 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } - ctx.n_sample++; + state.n_sample++; return result; } static std::vector whisper_sample_token_topk( - whisper_context & ctx, + whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, int k) { const auto & vocab = ctx.vocab; @@ -3381,7 +3407,7 @@ static std::vector whisper_sample_token_topk( const int n_logits = vocab.n_vocab; - auto & logits_id = ctx.logits_id; + auto & logits_id = state.logits_id; logits_id.clear(); for (int i = 0; i < n_logits; ++i) { @@ -3434,7 +3460,7 @@ static std::vector whisper_sample_token_topk( } } - ctx.n_sample++; + state.n_sample++; return result; } @@ -3488,24 +3514,23 @@ static void whisper_sequence_score( } } -int whisper_full( +int whisper_full_state( struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { - // clear old results - auto & result_all = ctx->result_all; - - result_all.clear(); + struct 
whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) { + + auto& result_all = state->result_all; // compute log mel spectrogram if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -1; } } else { - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -2; } @@ -3515,26 +3540,26 @@ int whisper_full( if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { std::vector probs(whisper_lang_max_id() + 1, 0.0f); - const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + const auto lang_id = whisper_lang_auto_detect(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - ctx->lang_id = lang_id; + state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; - ctx->tid_last = 0; - ctx->energy = get_signal_energy(samples, n_samples, 32); + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + state->energy = get_signal_energy(samples, n_samples, 32); } const int seek_start = params.offset_ms/10; - const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + const int seek_end = seek_start + (params.duration_ms == 0 ? 
whisper_n_len(state) : params.duration_ms/10); // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s @@ -3572,10 +3597,10 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 1; j < n_decoders; j++) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = ctx->decoders[0].kv_self; + decoder.kv_self = state->decoders[0].kv_self; if (!kv_cache_reinit(decoder.kv_self)) { fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; @@ -3583,7 +3608,7 @@ int whisper_full( WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); @@ -3592,7 +3617,7 @@ int whisper_full( } // the accumulated text context so far - auto & prompt_past = ctx->prompt_past; + auto & prompt_past = state->prompt_past; if (params.no_context) { prompt_past.clear(); } @@ -3611,13 +3636,13 @@ int whisper_full( fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } - ctx->exp_n_audio_ctx = params.audio_ctx; + state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); - ctx->lang_id = lang_id; + state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -3669,14 +3694,14 @@ int whisper_full( } if (params.encoder_begin_callback) { - if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); break; } } // encode audio features starting at offset seek - if (!whisper_encode(*ctx, seek, params.n_threads)) { + if (!whisper_encode(*ctx, *state, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } @@ -3717,7 +3742,7 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; decoder.kv_self.n = 0; @@ -3759,7 +3784,7 @@ int whisper_full( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + if (!whisper_decode(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } @@ -3767,24 +3792,24 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); - ctx->decoders[0].kv_self.n += prompt.size(); + state->decoders[0].kv_self.n += prompt.size(); for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; - memcpy(decoder.kv_self.k->data, 
ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); decoder.kv_self.n += prompt.size(); - memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } @@ -3795,7 +3820,7 @@ int whisper_full( if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { kv_bufs.resize(n_decoders_cur); for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3813,7 +3838,7 @@ int whisper_full( // generate new sequence candidates for each decoder for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3823,16 +3848,16 @@ int whisper_full( case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx,*state, decoder, true)); } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false)); } decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); for (const auto & token : tokens_new) { beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); @@ -3857,7 +3882,7 @@ int whisper_full( int cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3886,7 +3911,7 @@ int whisper_full( // - check if the sequence is failed // - update sliding window based on timestamp tokens for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3968,7 +3993,7 @@ int whisper_full( bool completed_all = true; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3982,11 
+4007,11 @@ int whisper_full( } } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; // obtain logits for the next token for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed || decoder.completed) { continue; @@ -3997,7 +4022,7 @@ int whisper_full( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + if (!whisper_decode(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -8; } @@ -4005,11 +4030,11 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, decoder, t_cur); + whisper_process_logits(*ctx, *state, params, decoder, t_cur); ++decoder.kv_self.n; - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } } @@ -4019,7 +4044,7 @@ int whisper_full( double best_score = -INFINITY; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed) { continue; @@ -4036,7 +4061,7 @@ int whisper_full( __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; - ctx->n_fail_h++; + state->n_fail_h++; continue; } @@ -4054,11 +4079,11 @@ int whisper_full( { bool success = true; - const auto & decoder = ctx->decoders[best_decoder_id]; + const auto & decoder = state->decoders[best_decoder_id]; if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { success = false; - ctx->n_fail_p++; + state->n_fail_p++; } if (success) { @@ -4075,7 +4100,7 @@ int whisper_full( // output results through a user-provided callback { - const auto & best_decoder = ctx->decoders[best_decoder_id]; + const auto & best_decoder = state->decoders[best_decoder_id]; const auto seek_delta = best_decoder.seek_delta; const auto result_len = best_decoder.sequence.result_len; @@ -4138,14 +4163,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -4182,14 +4207,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); 
+ params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } } @@ -4204,53 +4229,48 @@ int whisper_full( return 0; } -int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors) { - if (n_processors == 1) { - return whisper_full(ctx, params, samples, n_samples); - } +int whisper_full( + struct whisper_context* ctx, + struct whisper_full_params params, + const float* samples, + int n_samples) { - int ret = 0; + // Initialize a new state for the current transformation + whisper_state* state = whisper_init_state(ctx); - // prepare separate contexts for each thread - std::vector ctxs(n_processors - 1); + int result = whisper_full_state(ctx, state, params, samples, n_samples); - for (int i = 0; i < n_processors - 1; ++i) { - auto & ctx_p = ctxs[i]; + whisper_free_state(state); + + return result; - ctx_p = *ctx; +} - ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } - ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); + int ret = 0; - if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); - return false; - } + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + const int64_t offset_t = (int64_t)params.offset_ms / 10.0; - // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); - return false; - } + auto aggregated_whisper_state = whisper_init_state(ctx); - ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + std::vector states{}; - ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); - } + for (int i = 0; i < n_processors; i++) + { + states.push_back(whisper_init_state(ctx)); } - const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; - const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; - // the calling thread will process the first chunk // while the other threads will process the remaining chunks @@ -4268,60 +4288,60 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; - workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + workers[i] = std::thread(whisper_full_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur); } { auto params_cur = params; - ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + // Run the first transformation using the first state and for the first chunk. 
+ ret = whisper_full_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor); } for (int i = 0; i < n_processors - 1; ++i) { workers[i].join(); } - const int64_t offset_t = (int64_t) params.offset_ms/10.0; - // combine results into ctx->result_all for (int i = 0; i < n_processors - 1; ++i) { - auto & results_i = ctxs[i].result_all; + auto & results_i = states[i]->result_all; for (auto & result : results_i) { // correct the segment timestamp taking into account the offset result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + // make sure that segments are not overlapping - if (!ctx->result_all.empty()) { - result.t0 = std::max(result.t0, ctx->result_all.back().t1); + if (!aggregated_whisper_state->result_all.empty()) { + result.t0 = std::max(result.t0, aggregated_whisper_state->result_all.back().t1); } - ctx->result_all.push_back(std::move(result)); + aggregated_whisper_state->result_all.push_back(std::move(result)); // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, aggregated_whisper_state, 1, params.new_segment_callback_user_data); } } - ctx->t_mel_us += ctxs[i].t_mel_us; - ctx->t_sample_us += ctxs[i].t_sample_us; - ctx->t_encode_us += ctxs[i].t_encode_us; - ctx->t_decode_us += ctxs[i].t_decode_us; + aggregated_whisper_state->t_mel_us += states[i]->t_mel_us; - kv_cache_free(ctx->kv_cross); + aggregated_whisper_state->t_sample_us += states[i]->t_sample_us; + aggregated_whisper_state->t_encode_us += states[i]->t_encode_us; + aggregated_whisper_state->t_decode_us += states[i]->t_decode_us; - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - kv_cache_free(ctx->decoders[j].kv_self); - } + whisper_free_state(states[i]); } // average the timings - ctx->t_mel_us /= n_processors; - ctx->t_sample_us /= n_processors; - ctx->t_encode_us /= n_processors; - ctx->t_decode_us /= n_processors; + aggregated_whisper_state->t_mel_us /= n_processors; + aggregated_whisper_state->t_sample_us /= n_processors; + aggregated_whisper_state->t_encode_us /= n_processors; + aggregated_whisper_state->t_decode_us /= n_processors; // print information about the audio boundaries fprintf(stderr, "\n"); @@ -4331,47 +4351,49 @@ int whisper_full_parallel( } fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); + whisper_free_state(aggregated_whisper_state); + return ret; } -int whisper_full_n_segments(struct whisper_context * ctx) { - return ctx->result_all.size(); +int whisper_full_n_segments(struct whisper_state * state) { + return state->result_all.size(); } -int whisper_full_lang_id(struct whisper_context * ctx) { - return ctx->lang_id; +int whisper_full_lang_id(struct whisper_state * state) { + return state->lang_id; } -int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t0; +int64_t whisper_full_get_segment_t0(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t0; } -int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t1; +int64_t whisper_full_get_segment_t1(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t1; } -const char * whisper_full_get_segment_text(struct whisper_context 
* ctx, int i_segment) { - return ctx->result_all[i_segment].text.c_str(); +const char * whisper_full_get_segment_text(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); } -int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].tokens.size(); +int whisper_full_n_tokens(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); } -const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +const char * whisper_full_get_token_text(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); } -whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].id; +whisper_token whisper_full_get_token_id(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; } -struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token]; +struct whisper_token_data whisper_full_get_token_data(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; } -float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].p; +float whisper_full_get_token_p(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; } // ================================================================================================= @@ -4582,14 +4604,16 @@ static std::vector get_signal_energy(const float * signal, int n_samples, } static void whisper_exp_compute_token_level_timestamps( - struct whisper_context & ctx, - int i_segment, - float thold_pt, - float thold_ptsum) { - auto & segment = ctx.result_all[i_segment]; + struct whisper_context & ctx, + struct whisper_state & state, + int i_segment, + float thold_pt, + float thold_ptsum) { + + auto & segment = state.result_all[i_segment]; auto & tokens = segment.tokens; - const int n_samples = ctx.energy.size(); + const int n_samples = state.energy.size(); if (n_samples == 0) { fprintf(stderr, "%s: no signal data available\n", __func__); @@ -4612,9 +4636,9 @@ static void whisper_exp_compute_token_level_timestamps( return; } - auto & t_beg = ctx.t_beg; - auto & t_last = ctx.t_last; - auto & tid_last = ctx.tid_last; + auto & t_beg = state.t_beg; + auto & t_last = state.t_last; + auto & tid_last = state.tid_last; for (int j = 0; j < n; ++j) { auto & token = tokens[j]; @@ -4737,15 +4761,15 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; for (int k = ss0; k < ss1; k++) { - sum += ctx.energy[k]; + sum += state.energy[k]; } const float thold = 0.5*sum/ns; { int k = s0; - if (ctx.energy[k] > thold && j > 0) { - while (k > 0 && ctx.energy[k] > thold) { + if (state.energy[k] > thold && j > 0) { + while (k > 0 && state.energy[k] > thold) { k--; } tokens[j].t0 = sample_to_timestamp(k); @@ -4755,7 +4779,7 @@ static void whisper_exp_compute_token_level_timestamps( s0 = k; } } else { - while (ctx.energy[k] < thold && k < s1) { + while 
@@ -4582,14 +4604,16 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 }

 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context & ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum) {
-    auto & segment = ctx.result_all[i_segment];
+        struct whisper_context & ctx,
+          struct whisper_state & state,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+
+    auto & segment = state.result_all[i_segment];
     auto & tokens  = segment.tokens;

-    const int n_samples = ctx.energy.size();
+    const int n_samples = state.energy.size();

     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -4612,9 +4636,9 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }

-    auto & t_beg    = ctx.t_beg;
-    auto & t_last   = ctx.t_last;
-    auto & tid_last = ctx.tid_last;
+    auto & t_beg    = state.t_beg;
+    auto & t_last   = state.t_last;
+    auto & tid_last = state.tid_last;

     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
@@ -4737,15 +4761,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;

             for (int k = ss0; k < ss1; k++) {
-                sum += ctx.energy[k];
+                sum += state.energy[k];
             }

             const float thold = 0.5*sum/ns;

             {
                 int k = s0;
-                if (ctx.energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold && j > 0) {
+                    while (k > 0 && state.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -4755,7 +4779,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k < s1) {
+                    while (state.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
@@ -4765,8 +4789,8 @@ static void whisper_exp_compute_token_level_timestamps(

             {
                 int k = s1;
-                if (ctx.energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold) {
+                    while (k < n_samples - 1 && state.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -4776,7 +4800,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s1 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k > s0) {
+                    while (state.energy[k] < thold && k > s0) {
                         k--;
                     }
                     s1 = k;
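The hunks above only swap ctx for state; the underlying algorithm is unchanged: a token boundary is pushed outward while the local signal energy stays above thold (half the mean energy over the window), and pulled inward across low-energy samples otherwise. A standalone restatement of the start-boundary search, illustrative only, with names mirroring the code above:

#include <vector>

// Illustrative restatement of the start-boundary search in
// whisper_exp_compute_token_level_timestamps: walk left while the energy
// is above the threshold (we are inside a voiced region), otherwise walk
// right over silence until the voiced region begins.
static int refine_token_start(const std::vector<float> & energy,
                              float thold, int s0, int s1) {
    int k = s0;
    if (energy[k] > thold) {
        while (k > 0 && energy[k] > thold) {
            k--;
        }
    } else {
        while (energy[k] < thold && k < s1) {
            k++;
        }
    }
    return k;
}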
diff --git a/whisper.h b/whisper.h
index 7eece79..0f79602 100644
--- a/whisper.h
+++ b/whisper.h
@@ -32,7 +32,7 @@ extern "C" {
     //
     // C interface
     //
-    // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads
+    // The following interface is thread-safe as long as the same whisper_state is not used by multiple threads
     // concurrently.
     //
     // Basic usage:
@@ -66,6 +66,7 @@ extern "C" {
     //

     struct whisper_context;
+    struct whisper_state;

     typedef int whisper_token;

@@ -101,14 +102,21 @@ extern "C" {
     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);

+    // Create a new state used to run a transcription.
+    WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
+
     // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);

+    // Frees all memory allocated by the given state.
+    WHISPER_API void whisper_free_state(struct whisper_state * state);
+
     // Convert RAW PCM audio to log mel spectrogram.
     // The resulting spectrogram is stored inside the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
+            struct whisper_state * state,
             const float * samples,
             int n_samples,
             int n_threads);

@@ -117,28 +125,30 @@ extern "C" {
     // The resulting spectrogram is stored inside the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
-        struct whisper_context* ctx,
-        const float* samples,
-        int n_samples,
-        int n_threads);
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        const float * samples,
+        int n_samples,
+        int n_threads);

-    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+    // This can be used to set a custom log mel spectrogram inside the provided whisper state.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
     // Returns 0 on success
     WHISPER_API int whisper_set_mel(
-            struct whisper_context * ctx,
+            struct whisper_state * state,
             const float * data,
             int n_len,
             int n_mel);

-    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper state.
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
     // offset can be used to specify the offset of the first frame in the spectrogram.
     // Returns 0 on success
     WHISPER_API int whisper_encode(
             struct whisper_context * ctx,
+            struct whisper_state * state,
             int offset,
             int n_threads);

@@ -150,6 +160,7 @@ extern "C" {
     // TODO: add support for multiple decoders
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
+            struct whisper_state * state,
             const whisper_token * tokens,
             int n_tokens,
             int n_past,

@@ -186,11 +197,12 @@ extern "C" {
     // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
     WHISPER_API int whisper_lang_auto_detect(
             struct whisper_context * ctx,
+            struct whisper_state * state,
             int offset_ms,
             int n_threads,
             float * lang_probs);

-    WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_len           (struct whisper_state * state); // mel length
     WHISPER_API int whisper_n_vocab         (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx      (struct whisper_context * ctx);
     WHISPER_API int whisper_n_audio_ctx     (struct whisper_context * ctx);

@@ -200,7 +212,7 @@ extern "C" {
     // The logits for the last token are stored in the last row
     // Rows: n_tokens
     // Cols: n_vocab
-    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
+    WHISPER_API float * whisper_get_logits(struct whisper_state * state);

     // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);

@@ -219,8 +231,8 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_transcribe(void);

     // Performance information
-    WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
-    WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+    WHISPER_API void whisper_print_timings(struct whisper_context * ctx, struct whisper_state * state);
+    WHISPER_API void whisper_reset_timings(struct whisper_state * state);

     // Print system information
     WHISPER_API const char * whisper_print_system_info(void);

@@ -236,12 +248,12 @@ extern "C" {
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);

     // Encoder begin callback
     // If not NULL, called before the encoder starts
     // If it returns false, the computation is aborted
-    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

     // Parameters for the whisper_full() function
     // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:

@@ -339,31 +351,31 @@ extern "C" {

     // Number of generated text segments.
     // A segment can be a few words, a sentence, or even a paragraph.
-    WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments(struct whisper_state * state);

     // Language id associated with the current context
-    WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
+    WHISPER_API int whisper_full_lang_id(struct whisper_state * state);
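Both callback typedefs above gain a whisper_state parameter, so existing handlers need a signature update. A sketch of a migrated segment handler (my_new_segment is a hypothetical name; it uses the state-based getters from this patch):

#include <cstdio>

// Hypothetical handler for the new callback signature: the context is
// still passed (e.g. for vocabulary lookups), but the freshly generated
// segments are read from the state.
static void my_new_segment(struct whisper_context * ctx,
                           struct whisper_state * state,
                           int n_new, void * user_data) {
    (void) ctx;
    (void) user_data;
    const int n = whisper_full_n_segments(state);
    for (int i = n - n_new; i < n; ++i) {
        printf("%s", whisper_full_get_segment_text(state, i));
    }
}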
     // Get the start and end time of the specified segment.
-    WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_state * state, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_state * state, int i_segment);

     // Get the text of the specified segment.
-    WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
+    WHISPER_API const char * whisper_full_get_segment_text(struct whisper_state * state, int i_segment);

     // Get number of tokens in the specified segment.
-    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+    WHISPER_API int whisper_full_n_tokens(struct whisper_state * state, int i_segment);

     // Get the token text of the specified token in the specified segment.
-    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_state * state, int i_segment, int i_token);

     // Get token data for the specified token in the specified segment.
     // This contains probabilities, timestamps, etc.
-    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_state * state, int i_segment, int i_token);

     // Get the probability of the specified token in the specified segment.
-    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API float whisper_full_get_token_p(struct whisper_state * state, int i_segment, int i_token);

     ////////////////////////////////////////////////////////////////////////////
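Taken together, the context/state split lets one loaded model serve multiple independent transcriptions. A minimal lifecycle sketch, not part of the patch (assumptions: whisper_init_from_file and whisper_full_default_params keep their existing signatures, and whisper_full_state is the entry point used by whisper_full_parallel above; error handling elided):

#include "whisper.h"

// Sketch of the new lifecycle: the weights live in the context, while
// everything mutable lives in the state, created and freed per run.
static int transcribe(const char * model_path, const float * pcm, int n_samples) {
    struct whisper_context * ctx = whisper_init_from_file(model_path);
    struct whisper_state * state = whisper_init_state(ctx);

    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    const int ret = whisper_full_state(ctx, state, params, pcm, n_samples);

    // ... inspect results via whisper_full_n_segments(state) etc. ...

    whisper_free_state(state); // frees per-run buffers and KV caches
    whisper_free(ctx);         // frees the model weights
    return ret;
}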