diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 2bdaf87..6132509 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -413,7 +413,7 @@ bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float return true; } -std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { +std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); prob = 0.0f; @@ -435,22 +435,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; } int prob_n = 0; std::string result; - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); result += text; - const int n_tokens = whisper_full_n_tokens(ctx, i); + const int n_tokens = whisper_full_n_tokens(state, i); for (int j = 0; j < n_tokens; ++j) { - const auto token = whisper_full_get_token_data(ctx, i, j); + const auto token = whisper_full_get_token_data(state, i, j); prob += token.p; ++prob_n; @@ -545,7 +545,7 @@ bool process_sdl_events() { // command-list mode // guide the transcription to match the most likely command from a provided list -int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { +int process_command_list(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const 
whisper_params &params) { fprintf(stderr, "\n"); fprintf(stderr, "%s: guided mode\n", __func__); @@ -666,7 +666,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const wparams.prompt_n_tokens = k_tokens.size(); // run the transformer and a single decoding pass - if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { + if (whisper_full_with_state(ctx, state, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); break; } @@ -674,7 +674,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const // estimate command probability // NOTE: not optimal { - const auto * logits = whisper_get_logits(ctx); + const auto * logits = whisper_get_logits(state); std::vector probs(whisper_n_vocab(ctx), 0.0f); @@ -757,7 +757,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const // always-prompt mode // transcribe the voice into text after valid prompt -int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { +int always_prompt_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async & audio, const whisper_params & params) { bool is_running = true; bool ask_prompt = true; @@ -799,7 +799,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // detect the commands audio.get(params.command_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms)); const auto words = get_words(txt); @@ -835,7 +835,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // general-purpose mode // freely transcribe the voice into text -int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { +int 
process_general_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const whisper_params &params) { bool is_running = true; bool have_prompt = false; bool ask_prompt = true; @@ -879,7 +879,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // wait for activation phrase audio.get(params.prompt_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob0, t_ms)); fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); @@ -905,7 +905,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms)); prob = 100.0f*(prob - prob0); @@ -958,6 +958,8 @@ int main(int argc, char ** argv) { struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_state * state = whisper_init_state(ctx); + // print some info about the processing { fprintf(stderr, "\n"); @@ -995,16 +997,17 @@ int main(int argc, char ** argv) { int ret_val = 0; if (!params.commands.empty()) { - ret_val = process_command_list(ctx, audio, params); + ret_val = process_command_list(ctx, state, audio, params); } else if (!params.prompt.empty()) { - ret_val = always_prompt_transcription(ctx, audio, params); + ret_val = always_prompt_transcription(ctx, state, audio, params); } else { - ret_val = process_general_transcription(ctx, audio, params); + ret_val = process_general_transcription(ctx, state, audio, params); } audio.pause(); - whisper_print_timings(ctx); + whisper_print_timings(ctx, state); + whisper_free_state(state); whisper_free(ctx); return ret_val; 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0dba72f..3272908 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -503,7 +503,7 @@ int main(int argc, char ** argv) { } // state init for the current transformation - struct whisper_state* state = whisper_init_state(ctx); + struct whisper_state * state = whisper_init_state(ctx); // initial prompt std::vector prompt_tokens; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index b2a1824..2544bae 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -458,6 +458,7 @@ int main(int argc, char ** argv) { } struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_state * state = whisper_init_state(ctx); std::vector pcmf32 (n_samples_30s, 0.0f); std::vector pcmf32_old; @@ -622,7 +623,7 @@ int main(int argc, char ** argv) { wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 
0 : prompt_tokens.size(); - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; } @@ -645,9 +646,9 @@ int main(int argc, char ** argv) { printf("\n"); } - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); if (params.no_timestamps) { printf("%s", text); @@ -657,8 +658,8 @@ int main(int argc, char ** argv) { fout << text; } } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(state, i); + const int64_t t1 = whisper_full_get_segment_t1(state, i); printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); @@ -690,11 +691,11 @@ int main(int argc, char ** argv) { if (!params.no_context) { prompt_tokens.clear(); - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const int token_count = whisper_full_n_tokens(ctx, i); + const int token_count = whisper_full_n_tokens(state, i); for (int j = 0; j < token_count; ++j) { - prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); + prompt_tokens.push_back(whisper_full_get_token_id(state, i, j)); } } } @@ -704,7 +705,8 @@ int main(int argc, char ** argv) { audio.pause(); - whisper_print_timings(ctx); + whisper_print_timings(ctx, state); + whisper_free_state(state); whisper_free(ctx); return 0; diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index 55cd46a..b5f533e 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -16,6 +16,7 @@ #include #include #include 
+#include // command-line parameters struct whisper_params { @@ -419,7 +420,7 @@ bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float return true; } -std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { +std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); prob = 0.0f; @@ -441,22 +442,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; } int prob_n = 0; std::string result; - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); result += text; - const int n_tokens = whisper_full_n_tokens(ctx, i); + const int n_tokens = whisper_full_n_tokens(state, i); for (int j = 0; j < n_tokens; ++j) { - const auto token = whisper_full_get_token_data(ctx, i, j); + const auto token = whisper_full_get_token_data(state, i, j); prob += token.p; ++prob_n; @@ -500,6 +501,8 @@ int main(int argc, char ** argv) { struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str()); + struct whisper_state * state_wsp = whisper_init_state(ctx_wsp); + 
// gpt init struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str()); @@ -591,7 +594,7 @@ int main(int argc, char ** argv) { std::string text_heard; if (!force_speak) { - text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms)); + text_heard = ::trim(::transcribe(ctx_wsp, state_wsp, params, pcmf32_cur, prob0, t_ms)); } // remove text between brackets using regex @@ -688,7 +691,8 @@ int main(int argc, char ** argv) { audio.pause(); - whisper_print_timings(ctx_wsp); + whisper_print_timings(ctx_wsp, state_wsp); + whisper_free_state(state_wsp); whisper_free(ctx_wsp); return 0; diff --git a/whisper.cpp b/whisper.cpp index acb31da..dedc58e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2481,7 +2481,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return false; + return nullptr; } { @@ -2491,7 +2491,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return false; + return nullptr; } { @@ -3514,7 +3514,7 @@ static void whisper_sequence_score( } } -int whisper_full_state( +int whisper_full_with_state( struct whisper_context * ctx, struct whisper_state * state, struct whisper_full_params params, @@ -3523,6 +3523,11 @@ int whisper_full_state( auto& result_all = state->result_all; + // clear previous results in case some state is re-used + if (!result_all.empty()) { + result_all.clear(); + } + // compute log mel spectrogram if (params.speed_up) { if (whisper_pcm_to_mel_phase_vocoder(ctx, state, samples, n_samples, params.n_threads) != 0) { 
@@ -4238,7 +4243,7 @@ int whisper_full( // Initialize a new state for the current transformation whisper_state* state = whisper_init_state(ctx); - int result = whisper_full_state(ctx, state, params, samples, n_samples); + int result = whisper_full_with_state(ctx, state, params, samples, n_samples); whisper_free_state(state); @@ -4287,7 +4292,7 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; - workers[i] = std::thread(whisper_full_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur); + workers[i] = std::thread(whisper_full_with_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur); } { @@ -4297,7 +4302,7 @@ int whisper_full_parallel( params_cur.print_realtime = false; // Run the first transformation using the first state and for the first chunk. - ret = whisper_full_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor); + ret = whisper_full_with_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor); } for (int i = 0; i < n_processors - 1; ++i) { diff --git a/whisper.h b/whisper.h index a1e2ff2..0d67125 100644 --- a/whisper.h +++ b/whisper.h @@ -339,6 +339,14 @@ extern "C" { const float * samples, int n_samples); + // Same as whisper_full() but uses the specified state, so that the caller can read the result on the state after the method call + WHISPER_API int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + // Split the input audio in chunks and process each chunk separately using whisper_full() // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk.