From 9551d7fabd690f0354c72bf871ceeeeaaf36e499 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Jan 2023 22:15:54 +0200 Subject: [PATCH] whisper : wip decoding parameters + strategies --- whisper.cpp | 172 +++++++++++++++++++++++++++++++++++++--------------- whisper.h | 15 ++++- 2 files changed, 135 insertions(+), 52 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 52c70dc..6759cd2 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -416,11 +416,16 @@ struct whisper_model { std::map tensors; }; +struct whisper_sequence { + std::vector tokens; +}; + struct whisper_decoder { whisper_kv_cache kv_self; -}; -struct whisper_sequence { + whisper_sequence sequence; + + std::vector prompt; }; struct whisper_context { @@ -759,6 +764,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.work_logprobs.reserve(vocab.n_vocab); wctx.probs_id.reserve(n_vocab); + + wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx); } size_t ctx_size = 0; @@ -2766,46 +2774,54 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str case WHISPER_SAMPLING_GREEDY: { result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, - /*.language =*/ "en", + /*.language =*/ "en", + + /*.suppress_blank =*/ true, - /*.suppress_blank =*/ true, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, - /*.max_initial_timestamp =*/ 1.0, + /*.temperature_increment =*/ 0.2f, + /*.compression_ratio_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, - /*.greedy =*/ { - /*.dummy =*/ 0, + /*.greedy =*/ { + /*.best_of =*/ 5, }, - /*.beam_search =*/ { - /*.beam_width =*/ -1, - /*.n_best =*/ -1, + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2847,17 +2863,25 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", - /*.suppress_blank =*/ true, + /*.suppress_blank =*/ true, + + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, - /*.max_initial_timestamp =*/ 1.0, + /*.temperature_increment =*/ 0.2f, + /*.compression_ratio_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, - /*.greedy =*/ { - /*.dummy =*/ 0, + /*.greedy =*/ { + /*.best_of =*/ 5, }, - /*.beam_search =*/ { - /*.beam_width =*/ 10, - /*.n_best =*/ 5, + /*.beam_search =*/ { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, + /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -3142,6 +3166,45 @@ int whisper_full( return 0; } + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector temperatures; + if (params.temperature_increment > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + for (int i = 1; i < n_decoders; i++) { + if (ctx->decoders[i].kv_self.ctx == nullptr) { + ctx->decoders[i].kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(ctx->decoders[i].kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i); + return -4; + } + + fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i); + + ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity()); + } + } + // the accumulated text context so far auto & prompt_past = ctx->prompt_past; if (params.no_context) { @@ -3160,7 +3223,7 @@ int whisper_full( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); - return -4; + return -5; } ctx->exp_n_audio_ctx = params.audio_ctx; @@ -3201,12 +3264,6 @@ int whisper_full( break; } - // if there is a very short audio segment left to process, we remove any past prompt since it tends - // to confuse the decoder and often make it repeat or hallucinate stuff - if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); - } - if (params.encoder_begin_callback) { if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); @@ -3217,7 +3274,13 @@ int whisper_full( // encode audio features starting at offset seek if (whisper_encode(ctx, seek, params.n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); - return -4; + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); } prompt.clear(); @@ -3257,7 +3320,7 @@ int whisper_full( for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); - return -5; + return -7; } n_past += prompt.size(); @@ -3469,16 +3532,27 @@ int whisper_full_parallel( ctx_p = *ctx; + ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); + ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); + + ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab); + ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab); + + ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab); + if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__); + fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); return false; } for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__); + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); return false; } + + ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx); } } diff --git a/whisper.h b/whisper.h index 2983765..a3529c8 100644 --- a/whisper.h +++ b/whisper.h @@ -278,15 +278,24 @@ extern "C" { // common decoding parameters: bool suppress_blank; + float temperature; float max_initial_timestamp; + // fallback parameters + float temperature_increment; + float compression_ratio_threshold; + float logprob_threshold; + float no_speech_threshold; + struct { - int dummy; + int best_of; } greedy; struct { - int beam_width; - int n_best; + int beam_size; + + float patience; + float length_penalty; } beam_search; whisper_new_segment_callback new_segment_callback;