From 5e97f80fc5ac60c4247583de7399975a32124213 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 19:30:53 +0200 Subject: [PATCH] whisper : minor fixes --- examples/stream/stream.cpp | 6 +++++- whisper.cpp | 24 +++++++++++------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index c7aa871..c01a8df 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -459,7 +459,7 @@ int main(int argc, char ** argv) { struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); std::vector pcmf32 (n_samples_30s, 0.0f); - std::vector pcmf32_old(n_samples_30s, 0.0f); + std::vector pcmf32_old; std::vector pcmf32_new(n_samples_30s, 0.0f); std::vector prompt_tokens; @@ -615,6 +615,10 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + // disable best_of fallback + wparams.temperature_increment = -1.0f; + wparams.greedy.best_of = -1; + wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); diff --git a/whisper.cpp b/whisper.cpp index 385d183..73afa7c 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -25,7 +25,7 @@ } \ } while (0) -#define WHISPER_DEBUG +//#define WHISPER_DEBUG #if defined(WHISPER_DEBUG) #define WHISPER_PRINT_DEBUG(...) 
\ do { \ @@ -3216,6 +3216,7 @@ int whisper_full( // initialize the decoders int n_decoders = 1; + switch (params.strategy) { case WHISPER_SAMPLING_GREEDY: { @@ -3227,6 +3228,8 @@ int whisper_full( } break; }; + n_decoders = std::max(1, n_decoders); + // TAGS: WHISPER_DECODER_INIT for (int j = 1; j < n_decoders; j++) { auto & decoder = ctx->decoders[j]; @@ -3348,13 +3351,9 @@ int whisper_full( } break; }; - WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); - - if (t_cur > 0.5) { - prompt_past.clear(); + n_decoders_cur = std::max(1, n_decoders_cur); - WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__); - } + WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { @@ -3381,7 +3380,7 @@ int whisper_full( prompt.clear(); // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty()) { + if (!prompt_past.empty() && t_cur < 0.5f) { int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); prompt = { whisper_token_prev(ctx) }; @@ -3392,11 +3391,11 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); // print the prompt - //printf("\n\n"); + //WHISPER_PRINT_DEBUG("\n\n"); //for (int i = 0; i < (int) prompt.size(); i++) { - // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); //} - //printf("\n\n"); + //WHISPER_PRINT_DEBUG("\n\n"); if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3608,7 +3607,6 @@ int whisper_full( bool success = true; // implement logprob threshold - // TODO: 
implement compression threshold { auto & decoder = ctx->decoders[best_decoder_id]; @@ -3646,7 +3644,7 @@ int whisper_full( } // store the text from this iteration - if (!tokens_cur.empty()) { + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { int i0 = 0; auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));