diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 18c434e..d52e1d7 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -652,6 +652,9 @@ int main(int argc, char ** argv) {
 
         wparams.speed_up         = params.speed_up;
 
+        wparams.greedy.best_of        = params.best_of;
+        wparams.beam_search.beam_size = params.beam_size;
+
         wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
         wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0 : prompt_tokens.size();
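For context (not part of the patch): the two fields above are the same ones an application sets through the public API. A minimal usage sketch, assuming a whisper_context obtained from the usual init call and 16 kHz mono float samples; the value 5 is illustrative only.

    #include "whisper.h"

    #include <vector>

    // Transcribe a 16 kHz mono buffer using the new beam-search strategy.
    // beam_search.beam_size and greedy.best_of are the parameters wired up in
    // examples/main above.
    static int transcribe_beam(struct whisper_context * ctx, const std::vector<float> & pcm) {
        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

        wparams.beam_search.beam_size = 5; // top-k candidates kept per decoding step
        wparams.greedy.best_of        = 5; // candidates used for the temperature > 0 fallback

        return whisper_full(ctx, wparams, pcm.data(), (int) pcm.size());
    }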
diff --git a/whisper.cpp b/whisper.cpp
index 73afa7c..0fb0f42 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -433,10 +433,11 @@ struct whisper_sequence {
     // the accumulated transcription in the current interation (used to truncate the tokens array)
     int result_len;
 
-    double sum_logprobs; // the sum of the log probabilities of the tokens
-    double avg_logprobs; // the average log probability of the tokens
-    double entropy;      // the entropy of the tokens
-    double score;        // likelihood rank score
+    double sum_logprobs_all; // the sum of the log probabilities of the tokens
+    double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
+    double avg_logprobs;     // the average log probability of the tokens
+    double entropy;          // the entropy of the tokens
+    double score;            // likelihood rank score
 };
 
 // TAGS: WHISPER_DECODER_INIT
@@ -1650,7 +1651,7 @@ static bool whisper_decode(
     const int N = n_tokens;
     const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
     params.mem_size   = wctx.buf_compute.size();
@@ -3100,6 +3101,74 @@ static whisper_token_data whisper_sample_token(
     return result;
 }
 
+static std::vector<whisper_token_data> whisper_sample_token_topk(
+            whisper_context & ctx,
+      const whisper_decoder & decoder,
+                        int   k) {
+    const auto & vocab = ctx.vocab;
+
+    const auto & probs    = decoder.probs;
+    const auto & logits   = decoder.logits;
+    const auto & logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    auto & logits_id = ctx.logits_id;
+
+    logits_id.clear();
+    for (int i = 0; i < n_logits; ++i) {
+        logits_id.push_back({ logits[i], i });
+    }
+
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + k, logits_id.end(),
+            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
+                return a.first > b.first;
+            });
+
+    std::vector<whisper_token_data> result(k);
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            if (probs[i] == -INFINITY) {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i]) {
+                max_ts = probs[i];
+                result[0].tid = i;
+            }
+        }
+
+        result[0].pt    = max_ts/(sum_ts + 1e-10);
+        result[0].ptsum = sum_ts;
+    }
+
+    for (int i = 0; i < k; ++i) {
+        result[i].id   = logits_id[i].second;
+        result[i].p    = probs[result[i].id];
+        result[i].plog = logprobs[result[i].id];
+        result[i].tid   = result[0].tid;
+        result[i].pt    = result[0].pt;
+        result[i].ptsum = result[0].ptsum;
+        result[i].t0 = -1;
+        result[i].t1 = -1;
+        result[i].vlen = 0.0f;
+
+        if (result[i].id >= vocab.token_beg) {
+            result[i].tid = result[i].id;
+            result[i].pt  = result[i].p;
+        }
+    }
+
+    return result;
+}
+
 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
 static void whisper_sequence_score(
@@ -3120,7 +3189,7 @@ static void whisper_sequence_score(
     double penalty = sequence.result_len;
 
     if (params.length_penalty > 0.0f) {
-        penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
+        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
     }
 
     sequence.score = result/penalty;
@@ -3141,7 +3210,8 @@ static void whisper_sequence_score(
         for (const auto & kv : token_counts) {
             const auto p = kv.second/(double)cnt;
             entropy -= p*log(p);
-            //printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+
+            //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
         }
 
         sequence.entropy = entropy;
@@ -3293,6 +3363,25 @@ int whisper_full(
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
+    // beam-search helpers
+    struct kv_buf {
+        std::vector<uint8_t> k;
+        std::vector<uint8_t> v;
+    };
+
+    std::vector<kv_buf> kv_bufs;
+
+    struct beam_candidate {
+        int decoder_idx;
+        int seek_delta;
+
+        bool has_ts;
+
+        whisper_sequence sequence;
+    };
+
+    std::vector<beam_candidate> beam_candidates;
+
     // main loop
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
@@ -3360,11 +3449,12 @@ int whisper_full(
             auto & decoder = ctx->decoders[j];
 
             decoder.sequence.tokens.clear();
-            decoder.sequence.result_len   = 0;
-            decoder.sequence.sum_logprobs = -INFINITY;
-            decoder.sequence.avg_logprobs = -INFINITY;
-            decoder.sequence.entropy      = 0.0f;
-            decoder.sequence.score        = -INFINITY;
+            decoder.sequence.result_len       = 0;
+            decoder.sequence.sum_logprobs_all = 0.0;
+            decoder.sequence.sum_logprobs     = -INFINITY;
+            decoder.sequence.avg_logprobs     = -INFINITY;
+            decoder.sequence.entropy          = 0.0;
+            decoder.sequence.score            = -INFINITY;
 
             decoder.n_past     = 0;
             decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
@@ -3412,7 +3502,8 @@ int whisper_full(
             for (int j = 1; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
 
-                memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
+                memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
                 decoder.n_past += prompt.size();
@@ -3428,6 +3519,25 @@ int whisper_full(
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             const int64_t t_start_sample_us = ggml_time_us();
 
+            if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+                kv_bufs.resize(n_decoders_cur);
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
+                    kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
+
+                    memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
+                    memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
+                }
+
+                beam_candidates.clear();
+            }
+
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
@@ -3443,12 +3553,65 @@ int whisper_full(
                         } else {
                             decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
                         }
+
+                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
                     } break;
                 case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                     {
-                        // TODO: ..
+                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                        for (const auto & token : tokens_new) {
+                            beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
+                            beam_candidates.back().sequence.tokens.push_back(token);
+                            beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+
+                            //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
+                        }
                     } break;
                 };
+            }
+
+            if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+                std::sort(
+                        beam_candidates.begin(),
+                        beam_candidates.end(),
+                        [](const beam_candidate & a, const beam_candidate & b) {
+                    return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+                });
+
+                int cur_c = 0;
+
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    auto & cur = beam_candidates[cur_c++];
+
+                    while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                        ++cur_c;
+                    }
+
+                    decoder.sequence   = cur.sequence;
+                    decoder.seek_delta = cur.seek_delta;
+                    decoder.has_ts     = cur.has_ts;
+
+                    memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
+                    memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
+
+                    WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
+                            __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
+                }
+            }
+
+            for (int j = 0; j < n_decoders_cur; ++j) {
+                auto & decoder = ctx->decoders[j];
+
+                if (decoder.completed || decoder.failed) {
+                    continue;
+                }
 
                 auto & has_ts  = decoder.has_ts;
                 auto & failed  = decoder.failed;
@@ -3659,6 +3822,7 @@ int whisper_full(
                     } else {
                         text += whisper_token_to_str(ctx, tokens_cur[i].id);
                     }
+
                     if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
                         const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                         if (!text.empty()) {
@@ -4077,6 +4241,8 @@ static void whisper_exp_compute_token_level_timestamps(
             p1--;
         }
 
+        //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
+
         if (p1 > p0) {
             double psum = 0.0;
             for (int j = p0; j <= p1; j++) {
@@ -4205,11 +4371,11 @@ static void whisper_exp_compute_token_level_timestamps(
     // debug info
     //for (int j = 0; j < n; ++j) {
    //    const auto & token = tokens[j];
-    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
     //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
-    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
 
-    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
+    //    if (tokens[j].id >= whisper_token_eot(&ctx)) {
     //        continue;
     //    }
     //}
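The heart of the change to whisper_full above is the per-step candidate selection: each live decoder's KV cache is snapshotted, every decoder is expanded by its top-k tokens, all candidates are ranked by sum_logprobs_all, and each surviving decoder then adopts the best candidate that has not been taken yet, skipping candidates with an identical score so two decoders do not continue the same hypothesis. Below is a self-contained sketch of that selection step, with simplified stand-in types rather than the actual whisper.cpp structs (the bounds check on the duplicate skip is added here for the standalone example):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct candidate {
        int    decoder_idx;      // decoder this candidate was expanded from
        double sum_logprobs_all; // accumulated log-probability of the whole sequence
    };

    int main() {
        std::vector<candidate> cands = {
            { 0, -4.2 }, { 1, -3.1 }, { 0, -3.1 }, { 1, -5.0 },
        };

        // rank all candidates, best first
        std::sort(cands.begin(), cands.end(), [](const candidate & a, const candidate & b) {
            return a.sum_logprobs_all > b.sum_logprobs_all;
        });

        const int n_decoders = 2;

        size_t cur = 0;
        for (int j = 0; j < n_decoders; ++j) {
            const candidate best = cands[cur++];

            // skip candidates with the same score so that no two decoders
            // continue with identical hypotheses
            while (cur < cands.size() && cands[cur].sum_logprobs_all == best.sum_logprobs_all) {
                ++cur;
            }

            std::printf("decoder %d continues from decoder %d (sum_logprobs = %.2f)\n",
                    j, best.decoder_idx, best.sum_logprobs_all);
        }

        return 0;
    }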
diff --git a/whisper.h b/whisper.h
index 4bcb0e6..1266e0f 100644
--- a/whisper.h
+++ b/whisper.h
@@ -290,7 +290,7 @@ extern "C" {
         struct {
             int beam_size;
 
-            float patience;
+            float patience; // TODO: not implemented
         } beam_search;
 
         whisper_new_segment_callback new_segment_callback;
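For reference, the scoring used in the whisper_sequence_score hunks above divides the sequence's summed log-probability by ((5 + n)/6)^length_penalty when length_penalty > 0, and by the raw token count otherwise. A small standalone sketch of that formula (the helper name and example numbers are illustrative, not part of the library):

    #include <cmath>
    #include <cstdio>

    // Length-normalized sequence score: summed token log-probabilities divided by
    // ((5 + n)/6)^length_penalty, falling back to the raw length when the exponent
    // is <= 0.
    static double sequence_score(double sum_logprobs, int result_len, float length_penalty) {
        double penalty = result_len;

        if (length_penalty > 0.0f) {
            penalty = std::pow((5.0 + penalty)/6.0, length_penalty);
        }

        return sum_logprobs/penalty;
    }

    int main() {
        // example only: 20 kept tokens with a summed log-probability of -12.5
        std::printf("score = %.4f\n", sequence_score(-12.5, 20, 1.0f));
        return 0;
    }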