From c301a7942b615f7fb102b5ebbcc9e0638d4f5b00 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 18:25:16 +0200 Subject: [PATCH] whisper : add WHISPER_DEBUG + diagnostic prints + new main args --- README.md | 12 +---- examples/main/main.cpp | 119 +++++++++++++++++++++++------------------ whisper.cpp | 48 ++++++++++------- whisper.h | 2 +- 4 files changed, 99 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index f22724a..448e758 100644 --- a/README.md +++ b/README.md @@ -212,17 +212,7 @@ make large ## Limitations - Inference only -- No GPU support -- Very basic greedy sampling scheme - always pick up the token with highest probability. - This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274) - from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure - to run the python code with the following parameters: - - ``` - whisper --best_of None --beam_size None ... - ``` - - In the future, `whisper.cpp` will support more sampling strategies. 
+- No GPU support (yet) ## Another example diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d149e49..18c434e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,8 +59,12 @@ struct whisper_params { int32_t duration_ms = 0; int32_t max_context = -1; int32_t max_len = 0; + int32_t best_of = 5; + int32_t beam_size = -1; - float word_thold = 0.01f; + float word_thold = 0.01f; + float entropy_thold = 2.4f; + float logprob_thold = -1.0f; bool speed_up = false; bool translate = false; @@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } @@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during 
computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? 
"true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", 
params.logprob_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } @@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - if (text[0] == ' ') - text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + if (text[0] == ' ') { + text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + } const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. - fout << 10 * t0 << ", " - << 10 * t1 << ", \"" - << text << "\"\n"; + + //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. 
+ fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; } return true; } - // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments @@ -620,25 +631,29 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = false; - wparams.print_progress = params.print_progress; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special = params.print_special; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.token_timestamps = params.output_wts || params.max_len > 0; - wparams.thold_pt = params.word_thold; - wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; - - wparams.speed_up = params.speed_up; - - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? 
params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.entropy_threshold = params.entropy_thold; + wparams.logprob_threshold = params.logprob_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + + wparams.speed_up = params.speed_up; + + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); whisper_print_user_data user_data = { &params, &pcmf32s }; diff --git a/whisper.cpp b/whisper.cpp index b964aa1..385d183 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -25,6 +25,16 @@ } \ } while (0) +#define WHISPER_DEBUG +#if defined(WHISPER_DEBUG) +#define WHISPER_PRINT_DEBUG(...) \ + do { \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) +#else +#define WHISPER_PRINT_DEBUG(...) +#endif + #define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 @@ -1640,7 +1650,7 @@ static bool whisper_decode( const int N = n_tokens; const int M = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; - //fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx); + WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); struct ggml_init_params params; params.mem_size = wctx.buf_compute.size(); @@ -3177,8 +3187,8 @@ int whisper_full( } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; + ctx->t_beg = 0; + ctx->t_last = 0; ctx->tid_last = 0; ctx->energy = get_signal_energy(samples, n_samples, 32); } @@ -3228,7 +3238,7 @@ int whisper_full( return -4; } - fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j); + WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); @@ -3338,12 +3348,12 @@ int whisper_full( } break; }; - fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); if (t_cur > 0.5) { prompt_past.clear(); - fprintf(stderr, "%s: clearing prompt_past\n", __func__); + WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__); } // TAGS: WHISPER_DECODER_INIT @@ -3465,11 +3475,13 @@ int whisper_full( has_ts = true; } - //{ - // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; - // printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", - // __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); - //} +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif // end of segment if (token.id == whisper_token_eot(ctx) || // end of text token @@ -3541,7 +3553,7 @@ int whisper_full( decoder.tokens_tmp.resize(1); decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; - //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3574,11 +3586,11 @@ int whisper_full( decoder.sequence.tokens.resize(decoder.sequence.result_len); whisper_sequence_score(params, decoder.sequence); - fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) { - fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", __func__, j, decoder.sequence.entropy, params.entropy_threshold); decoder.failed = true; @@ -3606,15 +3618,15 @@ int whisper_full( } if (success) { - fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); + WHISPER_PRINT_DEBUG("%s: best decoder 
= %d\n", __func__, best_decoder_id); //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); //} break; } - fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } { @@ -3734,7 +3746,7 @@ int whisper_full( // update audio window seek += seek_delta; - fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta); + WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); } } diff --git a/whisper.h b/whisper.h index 69a2246..4bcb0e6 100644 --- a/whisper.h +++ b/whisper.h @@ -281,7 +281,7 @@ extern "C" { float temperature_increment; float entropy_threshold; // analog tho OpenAI's compression_ratio_threshold float logprob_threshold; - float no_speech_threshold; + float no_speech_threshold; // TODO: not implemented struct { int best_of;