diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d52e1d7..7dd9800 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -633,24 +633,24 @@ int main(int argc, char ** argv) {
 
         wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
-        wparams.print_realtime    = false;
-        wparams.print_progress    = params.print_progress;
-        wparams.print_timestamps  = !params.no_timestamps;
-        wparams.print_special     = params.print_special;
-        wparams.translate         = params.translate;
-        wparams.language          = params.language.c_str();
-        wparams.n_threads         = params.n_threads;
-        wparams.n_max_text_ctx    = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
-        wparams.offset_ms         = params.offset_t_ms;
-        wparams.duration_ms       = params.duration_ms;
-
-        wparams.token_timestamps  = params.output_wts || params.max_len > 0;
-        wparams.thold_pt          = params.word_thold;
-        wparams.entropy_threshold = params.entropy_thold;
-        wparams.logprob_threshold = params.logprob_thold;
-        wparams.max_len           = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
-
-        wparams.speed_up          = params.speed_up;
+        wparams.print_realtime   = false;
+        wparams.print_progress   = params.print_progress;
+        wparams.print_timestamps = !params.no_timestamps;
+        wparams.print_special    = params.print_special;
+        wparams.translate        = params.translate;
+        wparams.language         = params.language.c_str();
+        wparams.n_threads        = params.n_threads;
+        wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+        wparams.offset_ms        = params.offset_t_ms;
+        wparams.duration_ms      = params.duration_ms;
+
+        wparams.token_timestamps = params.output_wts || params.max_len > 0;
+        wparams.thold_pt         = params.word_thold;
+        wparams.entropy_thold    = params.entropy_thold;
+        wparams.logprob_thold    = params.logprob_thold;
+        wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+        wparams.speed_up         = params.speed_up;
 
         wparams.greedy.best_of        = params.best_of;
         wparams.beam_search.beam_size = params.beam_size;
diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp
index c01a8df..3432cb5 100644
--- a/examples/stream/stream.cpp
+++ b/examples/stream/stream.cpp
@@ -616,8 +616,8 @@ int main(int argc, char ** argv) {
             wparams.speed_up        = params.speed_up;
 
             // disable best_of fallback
-            wparams.temperature_increment = -1.0f;
-            wparams.greedy.best_of        = -1;
+            wparams.temperature_inc = -1.0f;
+            wparams.greedy.best_of  = -1;
 
             wparams.prompt_tokens   = params.no_context ? nullptr : prompt_tokens.data();
             wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
diff --git a/whisper.cpp b/whisper.cpp
index aad77b8..c400856 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -25,7 +25,9 @@
         } \
     } while (0)
 
+// define this to enable verbose trace logging - useful for debugging purposes
 //#define WHISPER_DEBUG
+
 #if defined(WHISPER_DEBUG)
 #define WHISPER_PRINT_DEBUG(...) \
     do { \
@@ -380,6 +382,8 @@ struct whisper_kv_cache {
     struct ggml_context * ctx;
 
     std::vector<uint8_t> buf;
+
+    int n; // number of tokens currently in the cache
 };
 
 struct whisper_model {
@@ -442,12 +446,13 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
+    // each decoder keeps its own KV-cache
     whisper_kv_cache kv_self;
 
+    // the currently generated sequence of tokens
    whisper_sequence sequence;
 
-    int n_past;
-    int seek_delta;
+    int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
     bool completed; // has the decoder completed the current segment?
@@ -476,6 +481,8 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;
 
+    // cross-attention KV cache for the decoders
+    // shared between all decoders
     whisper_kv_cache kv_cross;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
@@ -490,6 +497,7 @@ struct whisper_context {
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
+    // work container used to avoid memory allocations
     std::vector<std::pair<double, whisper_vocab::id>> logits_id;
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
@@ -680,6 +688,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
                 mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
     }
 
+    // initialize all memory buffers
+    // always have at least one decoder
+
     wctx.model.buf = new std::vector<uint8_t>();
     wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
@@ -2671,127 +2682,77 @@ const char * whisper_print_system_info(void) {
 
 ////////////////////////////////////////////////////////////////////////////
 
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
-    struct whisper_full_params result;
+    struct whisper_full_params result = {
+        /*.strategy          =*/ WHISPER_SAMPLING_GREEDY,
 
-    switch (strategy) {
-        case WHISPER_SAMPLING_GREEDY:
-            {
-                result = {
-                    /*.strategy          =*/ WHISPER_SAMPLING_GREEDY,
+        /*.n_threads         =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+        /*.n_max_text_ctx    =*/ 16384,
+        /*.offset_ms         =*/ 0,
+        /*.duration_ms       =*/ 0,
 
-                    /*.n_threads         =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx    =*/ 16384,
-                    /*.offset_ms         =*/ 0,
-                    /*.duration_ms       =*/ 0,
+        /*.translate         =*/ false,
+        /*.no_context        =*/ false,
+        /*.single_segment    =*/ false,
+        /*.print_special     =*/ false,
+        /*.print_progress    =*/ true,
+        /*.print_realtime    =*/ false,
+        /*.print_timestamps  =*/ true,
 
-                    /*.translate         =*/ false,
-                    /*.no_context        =*/ false,
-                    /*.single_segment    =*/ false,
-                    /*.print_special     =*/ false,
-                    /*.print_progress    =*/ true,
-                    /*.print_realtime    =*/ false,
-                    /*.print_timestamps  =*/ true,
+        /*.token_timestamps  =*/ false,
+        /*.thold_pt          =*/ 0.01f,
+        /*.thold_ptsum       =*/ 0.01f,
+        /*.max_len           =*/ 0,
+        /*.max_tokens        =*/ 0,
 
-                    /*.token_timestamps  =*/ false,
-                    /*.thold_pt          =*/ 0.01f,
-                    /*.thold_ptsum       =*/ 0.01f,
-                    /*.max_len           =*/ 0,
-                    /*.max_tokens        =*/ 0,
+        /*.speed_up          =*/ false,
+        /*.audio_ctx         =*/ 0,
 
-                    /*.speed_up          =*/ false,
-                    /*.audio_ctx         =*/ 0,
+        /*.prompt_tokens     =*/ nullptr,
+        /*.prompt_n_tokens   =*/ 0,
 
-                    /*.prompt_tokens     =*/ nullptr,
-                    /*.prompt_n_tokens   =*/ 0,
+        /*.language          =*/ "en",
 
-                    /*.language          =*/ "en",
+        /*.suppress_blank    =*/ true,
 
-                    /*.suppress_blank    =*/ true,
+        /*.temperature       =*/ 0.0f,
+        /*.max_initial_ts    =*/ 1.0f,
+        /*.length_penalty    =*/ -1.0f,
 
-                    /*.temperature           =*/ 0.0f,
-                    /*.max_initial_timestamp =*/ 1.0f,
-                    /*.length_penalty        =*/ -1.0f,
+        /*.temperature_inc   =*/ 0.2f,
+        /*.entropy_thold     =*/ 2.4f,
+        /*.logprob_thold     =*/ -1.0f,
+        /*.no_speech_thold   =*/ 0.6f,
 
-                    /*.temperature_increment =*/ 0.2f,
-                    /*.entropy_threshold     =*/ 2.4f,
-                    /*.logprob_threshold     =*/ -1.0f,
-                    /*.no_speech_threshold   =*/ 0.6f,
+        /*.greedy            =*/ {
+            /*.best_of   =*/ -1,
+        },
 
-                    /*.greedy =*/ {
-                        /*.best_of   =*/ 5,
-                    },
+        /*.beam_search       =*/ {
+            /*.beam_size =*/ -1,
 
-                    /*.beam_search =*/ {
-                        /*.beam_size =*/ -1,
+            /*.patience  =*/ -1.0f,
+        },
 
-                        /*.patience  =*/ -1.0f,
-                    },
+        /*.new_segment_callback           =*/ nullptr,
+        /*.new_segment_callback_user_data =*/ nullptr,
 
-                    /*.new_segment_callback           =*/ nullptr,
-                    /*.new_segment_callback_user_data =*/ nullptr,
+        /*.encoder_begin_callback           =*/ nullptr,
+        /*.encoder_begin_callback_user_data =*/ nullptr,
+    };
 
-                    /*.encoder_begin_callback           =*/ nullptr,
-                    /*.encoder_begin_callback_user_data =*/ nullptr,
+    switch (strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                result.greedy = {
+                    /*.best_of =*/ 1,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
            {
-                result = {
-                    /*.strategy          =*/ WHISPER_SAMPLING_BEAM_SEARCH,
-
-                    /*.n_threads         =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx    =*/ 16384,
-                    /*.offset_ms         =*/ 0,
-                    /*.duration_ms       =*/ 0,
-
-                    /*.translate         =*/ false,
-                    /*.no_context        =*/ false,
-                    /*.single_segment    =*/ false,
-                    /*.print_special     =*/ false,
-                    /*.print_progress    =*/ true,
-                    /*.print_realtime    =*/ false,
-                    /*.print_timestamps  =*/ true,
-
-                    /*.token_timestamps  =*/ false,
-                    /*.thold_pt          =*/ 0.01f,
-                    /*.thold_ptsum       =*/ 0.01f,
-                    /*.max_len           =*/ 0,
-                    /*.max_tokens        =*/ 0,
-
-                    /*.speed_up          =*/ false,
-                    /*.audio_ctx         =*/ 0,
-
-                    /*.prompt_tokens     =*/ nullptr,
-                    /*.prompt_n_tokens   =*/ 0,
-
-                    /*.language          =*/ "en",
-
-                    /*.suppress_blank    =*/ true,
-
-                    /*.temperature           =*/ 0.0f,
-                    /*.max_initial_timestamp =*/ 1.0f,
-                    /*.length_penalty        =*/ -1.0f,
-
-                    /*.temperature_increment =*/ 0.2f,
-                    /*.entropy_threshold     =*/ 2.4f,
-                    /*.logprob_threshold     =*/ -1.0f,
-                    /*.no_speech_threshold   =*/ 0.6f,
+                result.beam_search = {
+                    /*.beam_size =*/ 5,
 
-                    /*.greedy =*/ {
-                        /*.best_of   =*/ 5,
-                    },
-
-                    /*.beam_search =*/ {
-                        /*.beam_size =*/ 5,
-
-                        /*.patience  =*/ -1.0f,
-                    },
-
-                    /*.new_segment_callback           =*/ nullptr,
-                    /*.new_segment_callback_user_data =*/ nullptr,
-
-                    /*.encoder_begin_callback           =*/ nullptr,
-                    /*.encoder_begin_callback_user_data =*/ nullptr,
+                    /*.patience  =*/ -1.0f,
                 };
             } break;
     }
@@ -2862,8 +2823,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
 }
 
 // process the logits for the selected decoder
-// - applyies logit filters
-// - computes logprobs
+// - applies logit filters
+// - computes logprobs and probs
 static void whisper_process_logits(
         const struct whisper_context & ctx,
        const struct whisper_full_params params,
@@ -2938,11 +2899,11 @@ static void whisper_process_logits(
         }
     }
 
-    // the initial timestamp cannot be larger than max_initial_timestamp
+    // the initial timestamp cannot be larger than max_initial_ts
     // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
-    if (is_initial && params.max_initial_timestamp > 0.0f) {
+    if (is_initial && params.max_initial_ts > 0.0f) {
         const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
-        const int tid0 = std::round(params.max_initial_timestamp/precision);
+        const int tid0 = std::round(params.max_initial_ts/precision);
 
         for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
             logits[i] = -INFINITY;
@@ -3073,7 +3034,7 @@ static whisper_token_data whisper_sample_token(
             }
         }
 
-        result.pt = max_ts/(sum_ts + 1e-10);
+        result.pt    = max_ts/(sum_ts + 1e-10);
         result.ptsum = sum_ts;
     }
 
@@ -3127,7 +3088,13 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
             return a.first > b.first;
         });
 
-    std::vector<whisper_token_data> result(k);
+    std::vector<whisper_token_data> result;
+    result.reserve(k);
+
+    whisper_token tid;
+
+    float pt;
+    float ptsum;
 
     {
         double sum_ts = 0.0;
@@ -3141,24 +3108,18 @@
             sum_ts += probs[i];
 
             if (max_ts < probs[i]) {
                 max_ts = probs[i];
-                result[0].tid = i;
+                tid = i;
             }
         }
 
-        result[0].pt    = max_ts/(sum_ts + 1e-10);
-        result[0].ptsum = sum_ts;
+        pt    = max_ts/(sum_ts + 1e-10);
+        ptsum = sum_ts;
     }
 
     for (int i = 0; i < k; ++i) {
-        result[i].id    = logits_id[i].second;
-        result[i].p     = probs[result[i].id];
-        result[i].plog  = logprobs[result[i].id];
-        result[i].tid   = result[0].tid;
-        result[i].pt    = result[0].pt;
-        result[i].ptsum = result[0].ptsum;
-        result[i].t0    = -1;
-        result[i].t1    = -1;
-        result[i].vlen  = 0.0f;
+        const auto id = logits_id[i].second;
+
+        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
         if (result[i].id >= vocab.token_beg) {
             result[i].tid = result[i].id;
@@ -3276,8 +3237,8 @@
     // a set of temperatures to use
     // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
     std::vector<float> temperatures;
-    if (params.temperature_increment > 0.0f) {
-        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) {
+    if (params.temperature_inc > 0.0f) {
+        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
             temperatures.push_back(t);
         }
     } else {
@@ -3448,6 +3409,8 @@
         for (int j = 0; j < n_decoders_cur; ++j) {
             auto & decoder = ctx->decoders[j];
 
+            decoder.kv_self.n = 0;
+
             decoder.sequence.tokens.clear();
             decoder.sequence.result_len       = 0;
             decoder.sequence.sum_logprobs_all = 0.0;
@@ -3456,7 +3419,6 @@
             decoder.sequence.entropy          = 0.0;
             decoder.sequence.score            = -INFINITY;
 
-            decoder.n_past     = 0;
             decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
 
             decoder.failed    = false;
@@ -3497,7 +3459,7 @@
 
                 whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
 
-                ctx->decoders[0].n_past += prompt.size();
+                ctx->decoders[0].kv_self.n += prompt.size();
 
                 for (int j = 1; j < n_decoders_cur; ++j) {
                     auto & decoder = ctx->decoders[j];
@@ -3505,7 +3467,7 @@
                     memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
                     memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
-                    decoder.n_past += prompt.size();
+                    decoder.kv_self.n += prompt.size();
 
                     memcpy(decoder.probs.data(),  ctx->decoders[0].probs.data(),  decoder.probs.size()*sizeof(decoder.probs[0]));
                     memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -3519,6 +3481,7 @@
             for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
                 const int64_t t_start_sample_us = ggml_time_us();
 
+                // store the KV caches of all decoders when doing beam-search
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                     kv_bufs.resize(n_decoders_cur);
                     for (int j = 0; j < n_decoders_cur; ++j) {
@@ -3538,6 +3501,7 @@
                     beam_candidates.clear();
                 }
 
+                // generate new sequence candidates for each decoder
                 for (int j = 0; j < n_decoders_cur; ++j) {
                     auto & decoder = ctx->decoders[j];
@@ -3571,6 +3535,7 @@
                     };
                 }
 
+                // for beam-search, choose the top candidates and update the KV caches
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                     std::sort(
                             beam_candidates.begin(),
@@ -3606,6 +3571,10 @@
                     }
                 }
 
+                // update the decoder state
+                // - check if the sequence is completed
+                // - check if the sequence has failed
+                // - update sliding window based on timestamp tokens
                 for (int j = 0; j < n_decoders_cur; ++j) {
                     auto & decoder = ctx->decoders[j];
@@ -3705,6 +3674,7 @@
 
                 ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
+                // obtain logits for the next token
                 for (int j = 0; j < n_decoders_cur; ++j) {
                     auto & decoder = ctx->decoders[j];
@@ -3715,11 +3685,11 @@
                     decoder.tokens_tmp.resize(1);
                     decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
+                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
+                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
                         fprintf(stderr, "%s: failed to decode\n", __func__);
-                        return -7;
+                        return -8;
                     }
 
                     {
@@ -3727,7 +3697,7 @@
 
                         whisper_process_logits(*ctx, params, decoder, t_cur);
 
-                        ++decoder.n_past;
+                        ++decoder.kv_self.n;
 
                         ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
                     }
@@ -3736,7 +3706,7 @@
 
         // rank the resulting sequences and select the best one
         {
-            double best_score = -1e9;
+            double best_score = -INFINITY;
 
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
@@ -3751,9 +3721,9 @@
                 WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
                         __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
 
-                if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
+                if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
                     WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
-                            __func__, j, decoder.sequence.entropy, params.entropy_threshold);
+                            __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
                     decoder.failed = true;
@@ -3765,31 +3735,33 @@
                     best_decoder_id = j;
                 }
             }
-        }
 
-        bool success = true;
+            WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
+        }
 
-        // implement logprob threshold
+        // was the decoding successful for the current temperature?
         {
-            auto & decoder = ctx->decoders[best_decoder_id];
+            bool success = true;
+
+            const auto & decoder = ctx->decoders[best_decoder_id];
 
-            if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) {
+            if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
                 success = false;
             }
-        }
 
-        if (success) {
-            WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
-            //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
-            //    WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
-            //}
+            if (success) {
+                //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+                //    WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+                //}
 
-            break;
+                break;
+            }
         }
 
         WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
     }
 
+    // output results through a user-provided callback
     {
         const auto & best_decoder = ctx->decoders[best_decoder_id];
diff --git a/whisper.h b/whisper.h
index 1266e0f..84504b7 100644
--- a/whisper.h
+++ b/whisper.h
@@ -137,6 +137,7 @@ extern "C" {
     // tokens + n_tokens is the provided context for the decoder.
     // n_past is the number of tokens to use from previous decoder calls.
     // Returns 0 on success
+    // TODO: add support for multiple decoders
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
             const whisper_token * tokens,
@@ -218,8 +219,8 @@ extern "C" {
 
     // Available sampling strategies
     enum whisper_sampling_strategy {
-        WHISPER_SAMPLING_GREEDY,      // Always select the most probable token
-        WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
+        WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreedyDecoder
+        WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
     };
 
     // Text segment callback
@@ -239,17 +240,17 @@ extern "C" {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
-        int n_max_text_ctx;
+        int n_max_text_ctx;     // max tokens to use from past text as prompt for the decoder
         int offset_ms;          // start offset in ms
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;
+        bool no_context;        // do not use initial prompt for the decoder (if any)
         bool single_segment;    // force single segment output (useful for streaming)
-        bool print_special;
-        bool print_progress;
-        bool print_realtime;
-        bool print_timestamps;
+        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;    // print progress information
+        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;  // print timestamps for each text segment when printing realtime
 
         // [EXPERIMENTAL] token-level timestamps
         bool token_timestamps; // enable token-level timestamps
@@ -259,10 +260,11 @@ extern "C" {
         int  max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
+        // note: these can significantly reduce the quality of the output
         bool speed_up;         // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx;        // overwrite the audio context size (0 = use default)
 
-        // tokens to provide the whisper model as initial prompt
+        // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
@@ -271,31 +273,34 @@ extern "C" {
         const char * language;
 
         // common decoding parameters:
-        bool suppress_blank;
+        bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
 
-        float temperature;
-        float max_initial_timestamp;
-        float length_penalty;
+        float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
+        float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+        float length_penalty;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
 
         // fallback parameters
-        float temperature_increment;
-        float entropy_threshold; // analog tho OpenAI's compression_ratio_threshold
-        float logprob_threshold;
-        float no_speech_threshold; // TODO: not implemented
+        // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
+        float temperature_inc;
+        float entropy_thold;    // similar to OpenAI's "compression_ratio_threshold"
+        float logprob_thold;
+        float no_speech_thold;  // TODO: not implemented
 
         struct {
-            int best_of;
+            int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
         } greedy;
 
         struct {
-            int beam_size;
+            int beam_size;  // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
 
-            float patience; // TODO: not implemented
+            float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
         } beam_search;
 
+        // called for every newly generated text segment
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
 
+        // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
     };
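
Usage sketch (not part of the patch): the snippet below shows how a caller would drive the renamed fallback parameters from this change. whisper_full_default_params, whisper_full, and the field names are taken from the whisper.h hunks above; the helper name transcribe_sketch, the pcm buffer, and the sample count are hypothetical placeholders, and context creation/teardown is omitted.

#include "whisper.h"

// transcribe_sketch: hypothetical helper demonstrating the renamed fields
static int transcribe_sketch(struct whisper_context * ctx, const float * pcm, int n_samples) {
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // temperature fallback: whisper_full() retries at temperature + temperature_inc
    // whenever the entropy or average-logprob checks on the best decoder fail
    wparams.temperature     = 0.0f;
    wparams.temperature_inc = 0.2f;  // was: temperature_increment
    wparams.entropy_thold   = 2.4f;  // was: entropy_threshold
    wparams.logprob_thold   = -1.0f; // was: logprob_threshold
    wparams.max_initial_ts  = 1.0f;  // was: max_initial_timestamp

    // real-time callers can disable the fallback entirely, as stream.cpp does:
    //wparams.temperature_inc = -1.0f;
    //wparams.greedy.best_of  = -1;

    return whisper_full(ctx, wparams, pcm, n_samples);
}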