From a0da7f71a2b9d75cc858639380a21de2ef568ca7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 19 Feb 2023 19:39:05 +0200 Subject: [PATCH] command : wip in progress, improve guided decoding --- examples/addon.node/addon.cpp | 4 +- examples/command/command.cpp | 86 +++++++++++++++++++++++++++++++++-- examples/common-sdl.h | 6 +-- examples/main/main.cpp | 4 +- 4 files changed, 89 insertions(+), 11 deletions(-) diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 2ef895f..18a2085 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); } -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment(struct whisper_context * ctx, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -250,7 +250,7 @@ int run(whisper_params ¶ms, std::vector> &result) { // this callback is called on each new segment if (!wparams.print_realtime) { - wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback = whisper_print_segment; wparams.new_segment_callback_user_data = &user_data; } diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 2b9440a..3e3b533 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -109,6 +109,73 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); } +struct whisper_logits_filter_user_data { + std::vector * allowed_commands; + std::vector> * allowed_tokens; +}; + +void whisper_logits_filter( + struct whisper_context * ctx, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data){ + const auto & allowed_tokens = *((whisper_logits_filter_user_data *) user_data)->allowed_tokens; + + printf("n_tokens = %d\n", n_tokens); + for (int i = 0; i < n_tokens; i++) { + printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, tokens[i].id), logits[i]); + } + + if (n_tokens == 0) { + return; + } + + std::vector> pool; + for (int i = 0; i < (int) allowed_tokens.size(); i++) { + const int n = (int) allowed_tokens[i].size(); + if (n_tokens > n) { + continue; + } + + const whisper_token id = allowed_tokens[i][n_tokens - 1]; + pool.push_back({ id, logits[id] }); + } + + if (pool.empty()) { + return; + } + + printf("applying logits filter, pool size = %d\n", (int) pool.size()); + + const int ibeg = whisper_token_beg(ctx); + + double sum_all = 0.0; + for (int i = 0; i < ibeg; ++i) { + if (logits[i] == -INFINITY) { + continue; + } + sum_all += logits[i]; + } + + double sum_pool = 0.0; + for (int i = 0; i < (int) pool.size(); ++i) { + sum_pool += pool[i].second; + } + + printf("sum_all = %.2f, sum_pool = %.2f\n", sum_all, sum_pool); + + for (int i = 0; i < ibeg; ++i) { + logits[i] = -INFINITY; + } + + for (int i = 0; i < (int) pool.size(); ++i) { + //logits[pool[i].first] = pool[i].second / sum_pool * sum_all; + logits[pool[i].first] = pool[i].second; + printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, pool[i].first), logits[pool[i].first]); + } +} + std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); @@ -131,6 +198,8 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + wparams.temperature_inc = -1.0f; + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; } @@ -334,22 +403,31 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const wparams.translate = params.translate; wparams.no_context = true; wparams.single_segment = true; - wparams.max_tokens = 1; + //wparams.max_tokens = 1; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + wparams.temperature_inc = -1.0f; + wparams.prompt_tokens = k_tokens.data(); wparams.prompt_n_tokens = k_tokens.size(); + whisper_logits_filter_user_data user_data = { &allowed_commands, &allowed_tokens }; + + wparams.logits_filter_callback = whisper_logits_filter; + wparams.logits_filter_callback_user_data = &user_data; + // run the transformer and a single decoding pass if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); break; } + fprintf(stdout, "%s: text - '%s'\n", __func__, whisper_full_get_segment_text(ctx, 0)); + // estimate command probability // NOTE: not optimal { @@ -436,7 +514,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const // always-prompt mode // transcribe the voice into text after valid prompt -int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { +int process_always_prompt(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { bool is_running = true; bool ask_prompt = true; @@ -496,7 +574,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi const float sim = similarity(prompt, k_prompt); //debug - //fprintf(stdout, "command size: %i\n", command_length); + //fprintf(stdout, "command size: %d, sim: %f\n", (int) command.size(), sim); if ((sim > 0.7f) && (command.size() > 0)) { fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); @@ -676,7 +754,7 @@ int main(int argc, char ** argv) { if (!params.commands.empty()) { ret_val = process_command_list(ctx, audio, params); } else if (!params.prompt.empty()) { - ret_val = always_prompt_transcription(ctx, audio, params); + ret_val = process_always_prompt(ctx, audio, params); } else { ret_val = process_general_transcription(ctx, audio, params); } diff --git a/examples/common-sdl.h b/examples/common-sdl.h index a5659db..29d575f 100644 --- a/examples/common-sdl.h +++ b/examples/common-sdl.h @@ -1,13 +1,13 @@ #pragma once -#include -#include - #include #include #include #include +#include +#include + // // SDL Audio capture // diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b8366b7..212ef87 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -193,7 +193,7 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment(struct whisper_context * ctx, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -597,7 +597,7 @@ int main(int argc, char ** argv) { // this callback is called on each new segment if (!wparams.print_realtime) { - wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback = whisper_print_segment; wparams.new_segment_callback_user_data = &user_data; }