From 31ff0c6a1fe33cac266538e4a1f53fce9533fe4d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 21 Oct 2022 17:33:59 +0300 Subject: [PATCH] wip : experimental color coding of tokens based on probabilities --- main.cpp | 45 ++++++++++++++---- whisper.cpp | 135 ++++++++++++++++++++++++++++++++++------------------ whisper.h | 9 ++++ 3 files changed, 133 insertions(+), 56 deletions(-) diff --git a/main.cpp b/main.cpp index c9ac669..cf5ca95 100644 --- a/main.cpp +++ b/main.cpp @@ -5,12 +5,20 @@ #define DR_WAV_IMPLEMENTATION #include "dr_wav.h" +#include #include #include #include #include #include +// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] +// Lowest is red, middle is yellow, highest is green. +const std::vector k_colors = { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", +}; + // 500 -> 00:05.000 // 6000 -> 01:00.000 std::string to_timestamp(int64_t t) { @@ -41,6 +49,7 @@ struct whisper_params { bool output_vtt = false; bool output_srt = false; bool print_special_tokens = false; + bool print_colors = false; bool no_timestamps = false; std::string language = "en"; @@ -87,6 +96,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.output_srt = true; } else if (arg == "-ps" || arg == "--print_special") { params.print_special_tokens = true; + } else if (arg == "-pc" || arg == "--print_colors") { + params.print_colors = true; } else if (arg == "-nt" || arg == "--no_timestamps") { params.no_timestamps = true; } else if (arg == "-m" || arg == "--model") { @@ -122,6 +133,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); + fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); @@ -222,7 +234,7 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = true; + wparams.print_realtime = !params.print_colors; wparams.print_progress = false; wparams.print_timestamps = !params.no_timestamps; wparams.print_special_tokens = params.print_special_tokens; @@ -242,16 +254,34 @@ int main(int argc, char ** argv) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - if (params.no_timestamps) { - printf("%s", text); - fflush(stdout); + if (params.print_colors) { + // TODO + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); + fflush(stdout); + } } else { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } } } } @@ -260,7 +290,6 @@ int main(int argc, char ** argv) { // output to text file if (params.output_txt) { - const auto fname_txt = fname_inp + ".txt"; std::ofstream fout_txt(fname_txt); if (!fout_txt.is_open()) { @@ -279,7 +308,6 @@ int main(int argc, char ** argv) { // output to VTT file if (params.output_vtt) { - const auto fname_vtt = fname_inp + ".vtt"; std::ofstream fout_vtt(fname_vtt); if (!fout_vtt.is_open()) { @@ -304,7 +332,6 @@ int main(int argc, char ** argv) { // output to SRT file if (params.output_srt) { - const auto fname_srt = fname_inp + ".srt"; std::ofstream fout_srt(fname_srt); if (!fout_srt.is_open()) { diff --git a/whisper.cpp b/whisper.cpp index 09250c0..5c5f8bd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -210,9 +210,12 @@ struct whisper_vocab { } }; -struct whisper_result { - int64_t t; - whisper_token id; +struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float pt; // probability of the timestamp token }; struct whisper_segment { @@ -220,6 +223,8 @@ struct whisper_segment { int64_t t1; std::string text; + + std::vector tokens; }; // medium @@ -407,7 +412,7 @@ struct whisper_context { std::vector probs; std::vector logits; - std::vector result_cur; + std::vector tokens_cur; std::vector result_all; std::vector prompt_past; @@ -1786,9 +1791,11 @@ bool whisper_decode( } // the most basic sampling scheme - select the top token -whisper_vocab::id whisper_sample_best( +whisper_token_data whisper_sample_best( const whisper_vocab & vocab, const float * probs) { + whisper_token_data result; + int n_logits = vocab.id_to_token.size(); std::vector> probs_id; @@ -1798,24 +1805,33 @@ whisper_vocab::id whisper_sample_best( probs_id.push_back(std::make_pair(probs[i], i)); } - double sum_ts = 0.0; - double max_tx = 0.0; + { + double sum_ts = 0.0; + double max_ts = -1.0; + double max_tx = -1.0; - for (int i = 0; i < vocab.token_beg; i++) { - max_tx = std::max(max_tx, probs_id[i].first); - } + for (int i = 0; i < vocab.token_beg; i++) { + max_tx = std::max(max_tx, probs_id[i].first); + } - for (int i = vocab.token_beg; i < n_logits; i++) { - sum_ts += probs_id[i].first; - } + for (int i = vocab.token_beg; i < n_logits; i++) { + sum_ts += probs_id[i].first; + if (probs_id[i].first > max_ts) { + max_ts = probs_id[i].first; + result.tid = probs_id[i].second; + } + } - // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a - // timestamp token - if (sum_ts > max_tx) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 - for (int i = 0; i < vocab.token_beg; i++) { - probs_id[i].first = -INFINITY; + // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a + // timestamp token + if (sum_ts > max_tx) { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + for (int i = 0; i < vocab.token_beg; i++) { + probs_id[i].first = -INFINITY; + } } + + result.pt = max_ts/(sum_ts + 1e-6); } // find the top K tokens @@ -1843,7 +1859,10 @@ whisper_vocab::id whisper_sample_best( res++; } - return probs_id[res].second; + result.id = probs_id[res].second; + result.p = probs_id[res].first; + + return result; } // samples only from the timestamps tokens @@ -2178,7 +2197,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - return res; + return res.id; } whisper_token whisper_sample_timestamp(struct whisper_context * ctx) { @@ -2343,7 +2362,7 @@ int whisper_full( int n_samples) { // clear old results auto & result_all = ctx->result_all; - auto & result_cur = ctx->result_cur; + auto & tokens_cur = ctx->tokens_cur; result_all.clear(); @@ -2430,7 +2449,7 @@ int whisper_full( // the accumulated transcription in the current interation int result_len = 0; - result_cur.clear(); + tokens_cur.clear(); for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { @@ -2449,28 +2468,26 @@ int whisper_full( // feel free to experiment! // { - whisper_token id = 0; - whisper_token tid = whisper_token_beg(ctx); + auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); - id = whisper_sample_best(ctx); - if (i > 0) { - tid = whisper_sample_timestamp(ctx); + if (i == 0) { + token.tid = whisper_token_beg(ctx); } - // update sliding window - if (id > whisper_token_beg(ctx)) { - seek_delta = 2*(id - whisper_token_beg(ctx)); + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + seek_delta = 2*(token.id - whisper_token_beg(ctx)); result_len = i + 1; } // add it to the context - prompt.push_back(id); - result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id }); + prompt.push_back(token.id); + tokens_cur.push_back(token); //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str()); // end of text token - if (id == whisper_token_eot(ctx)) { + if (token.id == whisper_token_eot(ctx)) { if (result_len == 0) { if (seek + seek_delta + 100 >= whisper_n_len(ctx)) { result_len = i + 1; @@ -2494,25 +2511,30 @@ int whisper_full( } } - result_cur.resize(result_len); + tokens_cur.resize(result_len); - for (const auto & r : result_cur) { + for (const auto & r : tokens_cur) { prompt_past.push_back(r.id); } // store the text from this iteration - if (result_cur.size() > 0) { - auto t0 = result_cur.front().t; + if (tokens_cur.size() > 0) { + int i0 = 0; + auto t0 = 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text = ""; - for (int i = 0; i < (int) result_cur.size(); i++) { - if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) { + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { } else { - text += whisper_token_to_str(ctx, result_cur[i].id); + text += whisper_token_to_str(ctx, tokens_cur[i].id); } - if (result_cur[i].id > whisper_token_beg(ctx)) { - const auto t1 = result_cur[i].t; + if (tokens_cur[i].id > whisper_token_beg(ctx)) { + const auto t1 = 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { if (params.print_realtime) { if (params.print_timestamps) { @@ -2523,14 +2545,18 @@ int whisper_full( } } - result_all.push_back({ t0, t1, text }); + result_all.push_back({ t0, t1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } } text = ""; - while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) { + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { i++; } i--; - t0 = result_cur[i].t; + t0 = t1; + i0 = i + 1; } } @@ -2546,7 +2572,10 @@ int whisper_full( } } - result_all.push_back({ t0, t1, text }); + result_all.push_back({ t0, t1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } } } @@ -2571,3 +2600,15 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { return ctx->result_all[i_segment].text.c_str(); } + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].p; +} diff --git a/whisper.h b/whisper.h index 4423674..3435cd7 100644 --- a/whisper.h +++ b/whisper.h @@ -207,6 +207,15 @@ extern "C" { // Get the text of the specified segment. WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + // Get number of tokens in the specified segment. + WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + + // Get the token text of the specified token in the specified segment. + WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment. + WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + #ifdef __cplusplus } #endif