From abe104b80bb1871f3219c894bd062902fb11c2b4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 18 Dec 2022 13:58:25 +0200
Subject: [PATCH] whisper : prepare infra for new decoding strategies

---
 whisper.cpp | 87 +++++++++++++++++++++++++++++++++++++++++------------
 whisper.h   |  5 +--
 2 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index a645056..da43f5d 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -16,6 +16,14 @@
 #include <thread>
 #include <vector>
 
+#define WHISPER_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
 
@@ -423,8 +431,9 @@ struct whisper_context {
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
+    std::vector<whisper_token>   prompt_past;
 
-    std::vector<whisper_token> prompt_past;
+    std::vector<float> work_logits; // used to avoid allocations
 
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
@@ -2689,12 +2698,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
             /*.language =*/ "en",
 
+            /*.suppress_blank =*/ true,
+
             /*.greedy =*/ {
-                /*.n_past =*/ 0,
+                /*.dummy =*/ 0,
             },
 
             /*.beam_search =*/ {
-                /*.n_past =*/ -1,
                 /*.beam_width =*/ -1,
                 /*.n_best =*/ -1,
             },
@@ -2738,12 +2748,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
             /*.language =*/ "en",
 
+            /*.suppress_blank =*/ true,
+
             /*.greedy =*/ {
-                /*.n_past =*/ -1,
+                /*.dummy =*/ 0,
             },
 
             /*.beam_search =*/ {
-                /*.n_past =*/ 0,
                 /*.beam_width =*/ 10,
                 /*.n_best =*/ 5,
             },
@@ -2822,6 +2833,50 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     return res;
 }
 
+static struct whisper_token_data whisper_sample_next_token(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const std::vector<whisper_token> & prompt,
+        const std::vector<whisper_token_data> & tokens_cur) {
+    struct whisper_token_data result = {};
+
+    const auto & vocab = ctx->vocab;
+
+    const bool is_initial = tokens_cur.size() == 0;
+    const int  n_logits   = vocab.id_to_token.size();
+
+    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
+
+    // extract the logits for the last token
+    // we will be mutating and therefore we don't want to use the ctx->logits buffer directly
+    auto & logits = ctx->work_logits;
+    {
+        logits.resize(n_logits);
+        memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
+    }
+
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+    // TODO: apply logit filters here
+    {
+    }
+
+    switch (params.strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
+                // TODO: implement
+                result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+            } break;
+        case WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
+                // TODO: implement
+            } break;
+    }
+
+    return result;
+}
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -2955,7 +3010,6 @@ int whisper_full(
                 return -4;
             }
 
-            int n_past = 0;
             prompt.clear();
 
             // if we have already generated some text, use it as a prompt to condition the next generation
@@ -2971,8 +3025,6 @@
 
             prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-            int seek_delta = 100*WHISPER_CHUNK_SIZE;
-
             // print the prompt
             //printf("\n\n");
             //for (int i = 0; i < prompt.size(); i++) {
@@ -2980,11 +3032,14 @@
             //}
             //printf("\n\n");
 
+            int n_past     = 0;
+            int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
             // the accumulated transcription in the current iteration
            int result_len = 0;
             tokens_cur.clear();
 
-            bool failed = false;
+            bool failed = false; // has the current segment failed to decode?
             bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
 
             for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
@@ -2996,15 +3051,10 @@
                 n_past += prompt.size();
                 prompt.clear();
 
-                // very basic greedy sampling strategy:
-                //
-                //   - always take the most probable token
-                //
-                // more sophisticated sampling strategies could be implemented here, but we keep it simple
-                // feel free to experiment!
-                //
+                // sample the next token based on the selected decoding strategy + parameters
+                // also, update the sliding window position based on the sampled timestamp tokens
                 {
-                    const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+                    const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);
 
                     // timestamp token - update sliding window
                     if (token.id > whisper_token_beg(ctx)) {
@@ -3059,8 +3109,7 @@
                 }
 
                 // sometimes, the decoding can get stuck in a repetition loop
-                // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
-                // the sliding window by 1 second
+                // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                 if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                     failed = true;
                     break;
diff --git a/whisper.h b/whisper.h
index 63f61af..77fa89a 100644
--- a/whisper.h
+++ b/whisper.h
@@ -274,12 +274,13 @@ extern "C" {
         // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
 
+        bool suppress_blank;
+
         struct {
-            int n_past;
+            int dummy;
         } greedy;
 
         struct {
-            int n_past;
             int beam_width;
             int n_best;
         } beam_search;
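
Note: the new suppress_blank flag is wired into whisper_full_params and its
defaults above, but nothing consumes it yet - the empty "TODO: apply logit
filters here" block in whisper_sample_next_token is where it would take
effect. As a rough sketch of the SuppressBlank rule from the referenced
OpenAI decoding.py (suppress the plain space and end-of-text tokens when
sampling the first token of a segment), assuming the whisper_vocab fields
token_eot and token_to_id that exist elsewhere in whisper.cpp:

    // hypothetical logit filter sketch - not part of this patch
    if (params.suppress_blank) {
        if (is_initial) {
            logits[vocab.token_to_id.at(" ")] = -INFINITY; // the " " token
            logits[vocab.token_eot]           = -INFINITY; // the end-of-text token
        }
    }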
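
For callers, the parameter rework is a breaking change: greedy.n_past and
beam_search.n_past are gone, and greedy now carries only a dummy placeholder.
A caller-side sketch using only the API visible in this patch (pcm and
n_samples are hypothetical placeholders for the audio buffer):

    // hypothetical usage sketch
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.suppress_blank = true; // on by default; takes effect once the filter lands

    if (whisper_full(ctx, wparams, pcm, n_samples) != 0) {
        fprintf(stderr, "whisper_full() failed\n");
    }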