From 624e4f531370d254b4e06268506f704524f57dc9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Oct 2022 21:14:27 +0300 Subject: [PATCH] whisper : fix timestamp sampling --- examples/whisper/whisper.cpp | 39 +++++++++++++++++++++++------------- examples/whisper/whisper.h | 2 +- src/ggml.c | 3 +++ 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 236fcf1..2d2b8ce 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -1784,7 +1784,7 @@ bool whisper_decode( // the most basic sampling scheme - select the top token whisper_vocab::id whisper_sample_best( const whisper_vocab & vocab, - const float * probs, bool need_timestamp) { + const float * probs) { int n_logits = vocab.id_to_token.size(); std::vector> probs_id; @@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best( probs_id.push_back(std::make_pair(probs[i], i)); } - const int top_k = 4; + double sum_ts = 0.0; + double max_tx = 0.0; + + for (int i = 0; i < vocab.token_beg; i++) { + max_tx = std::max(max_tx, probs_id[i].first); + } + + for (int i = vocab.token_beg; i < n_logits; i++) { + sum_ts += probs_id[i].first; + } + + // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a + // timestamp token + if (sum_ts > max_tx) { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + for (int i = 0; i < vocab.token_beg; i++) { + probs_id[i].first = -INFINITY; + } + } // find the top K tokens + const int top_k = 4; + std::partial_sort( probs_id.begin(), probs_id.begin() + top_k, probs_id.end(), @@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best( // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); //} - if (need_timestamp) { - // at the end of the 30-second audio segment, we start giving preference to time tokens - for (int i = 0; i < top_k; i++) { - if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) { - return probs_id[i].second; - } - } - } - int res = 0; while ((probs_id[res].second == vocab.token_sot || probs_id[res].second == vocab.token_solm || @@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return 0; } -whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) { +whisper_token whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); // TODO: simplify - auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp); + auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2437,7 +2448,7 @@ int whisper_full( whisper_token id = 0; whisper_token tid = whisper_token_beg(ctx); - id = whisper_sample_best(ctx, result_len == 0); + id = whisper_sample_best(ctx); if (i > 0) { tid = whisper_sample_timestamp(ctx); } diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 45faa5b..4423674 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -120,7 +120,7 @@ extern "C" { // You can also implement your own sampling method using the whisper_get_probs() function. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp); + WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); // Return the id of the specified language, returns -1 if not found diff --git a/src/ggml.c b/src/ggml.c index 4861f24..115e619 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -75,6 +75,9 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { #include +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + static inline float fp32_from_bits(uint32_t w) { union { uint32_t as_bits;