From 7affd309d34ce36fe219956d7ec5d544163e5ae7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 22 Oct 2022 21:06:50 +0300 Subject: [PATCH] whisper : add new-segment callback Can be used to process new segments as they are being generated. Sample usage in main, for printing the resulting segments during the inference. --- main.cpp | 95 +++++++++++++++++++++++++++++++---------------------- whisper.cpp | 16 +++++++++ whisper.h | 9 +++++ 3 files changed, 81 insertions(+), 39 deletions(-) diff --git a/main.cpp b/main.cpp index 1fcc106..b064860 100644 --- a/main.cpp +++ b/main.cpp @@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } +void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { + const whisper_params & params = *(whisper_params *) user_data; + + const int n_segments = whisper_full_n_segments(ctx); + + // print the last segment + const int i = n_segments - 1; + if (i == 0) { + printf("\n"); + } + + if (params.no_timestamps) { + if (params.print_colors) { + // TODO + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); + fflush(stdout); + } + } else { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } + } +} + bool output_txt(struct whisper_context * ctx, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { @@ -294,7 +343,7 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = !params.print_colors; + wparams.print_realtime = false; wparams.print_progress = false; wparams.print_timestamps = !params.no_timestamps; wparams.print_special_tokens = params.print_special_tokens; @@ -303,49 +352,17 @@ int main(int argc, char ** argv) { wparams.n_threads = params.n_threads; wparams.offset_ms = params.offset_t_ms; + // this callback is called on each new segment + if (!wparams.print_realtime) { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = ¶ms; + } + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 7; } - // print result - if (!wparams.print_realtime) { - printf("\n"); - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - if (params.no_timestamps) { - if (params.print_colors) { - // TODO - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - printf("%s", text); - fflush(stdout); - } - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - - if (params.print_colors) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); - } - printf("\n"); - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); - } - } - } - } - printf("\n"); // output to text file diff --git a/whisper.cpp b/whisper.cpp index 5c5f8bd..01f6b00 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_width =*/ -1, /*.n_best =*/ -1, }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: @@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_width =*/ 10, /*.n_best =*/ 5, }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, }; } break; } @@ -2549,6 +2555,9 @@ int whisper_full( for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, params.new_segment_callback_user_data); + } } text = ""; while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { @@ -2576,6 +2585,9 @@ int whisper_full( for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, params.new_segment_callback_user_data); + } } } @@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); } +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].id; +} + float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { return ctx->result_all[i_segment].tokens[i_token].p; } diff --git a/whisper.h b/whisper.h index 3435cd7..53b0041 100644 --- a/whisper.h +++ b/whisper.h @@ -160,6 +160,11 @@ extern "C" { WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! }; + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); + struct whisper_full_params { enum whisper_sampling_strategy strategy; @@ -184,6 +189,9 @@ extern "C" { int beam_width; int n_best; } beam_search; + + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); @@ -212,6 +220,7 @@ extern "C" { // Get the token text of the specified token in the specified segment. WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); // Get the probability of the specified token in the specified segment. WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);