From 2ba90581ce1ad19546b1925e5fed5bf5bbea5973 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 27 Nov 2022 19:59:06 +0200 Subject: [PATCH] whisper : try to improve the token sampling strategy - Add the "max_initial_timestaamp" token logic from OpenAI - Disallow sampling timestamps that are in the past --- whisper.cpp | 71 +++++++++++++++++++---------------------------------- whisper.h | 2 +- 2 files changed, 26 insertions(+), 47 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index fbcb5d1..ad33add 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1846,7 +1846,9 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, - const float * probs) { + const float * probs, + bool force_timestamp, + bool is_initial) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -1869,6 +1871,14 @@ static whisper_token_data whisper_sample_best( max_tx = std::max(max_tx, probs_id[i].first); } + // the initial timestamp cannot be larger than 100 + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial) { + for (int i = vocab.token_beg + 101; i < (int) probs_id.size(); ++ i) { + probs_id[i].first = -INFINITY; + } + } + for (int i = vocab.token_beg; i < n_logits; i++) { sum_ts += probs_id[i].first; if (probs_id[i].first > max_ts) { @@ -1879,7 +1889,7 @@ static whisper_token_data whisper_sample_best( // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a // timestamp token - if (sum_ts > max_tx) { + if (sum_ts > max_tx || force_timestamp) { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for (int i = 0; i < vocab.token_beg; i++) { probs_id[i].first = -INFINITY; @@ -1921,39 +1931,6 @@ static whisper_token_data whisper_sample_best( return result; } -// samples only from the timestamps tokens -static whisper_vocab::id whisper_sample_timestamp( - const whisper_vocab & vocab, - const float * probs) { - int n_logits = vocab.id_to_token.size(); - - std::vector> probs_id; - probs_id.reserve(n_logits); - - for (int i = vocab.token_beg + 1; i < n_logits; i++) { - probs_id.push_back(std::make_pair(probs[i], i)); - } - - const int top_k = 10; - - // find the top K tokens - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - return probs_id[0].second; -} - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2284,19 +2261,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; return res; } -whisper_token whisper_sample_timestamp(struct whisper_context * ctx) { +struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2725,15 +2700,19 @@ int whisper_full( // feel free to experiment! // { - auto token = whisper_sample_best(ctx); - - if (i == 0) { - token.tid = whisper_token_beg(ctx); - } + const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { - seek_delta = 2*(token.id - whisper_token_beg(ctx)); + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (seek_delta != 100*WHISPER_CHUNK_SIZE && + seek_delta > seek_delta_new && result_len < i) { + break; + } + + seek_delta = seek_delta_new; result_len = i + 1; } diff --git a/whisper.h b/whisper.h index 156edbb..def77d4 100644 --- a/whisper.h +++ b/whisper.h @@ -137,7 +137,7 @@ extern "C" { // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); // Return the id of the specified language, returns -1 if not found WHISPER_API int whisper_lang_id(const char * lang);