From 116dd67a158d328211bf2bd96b1e903bff8a4b8e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 11:09:23 +0200 Subject: [PATCH] whisper : wip decoding parameters + strategies (part 3) --- whisper.cpp | 664 +++++++++++++++++++++++++--------------------------- whisper.h | 3 +- 2 files changed, 327 insertions(+), 340 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 4e11e67..ed6c52d 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -422,7 +422,9 @@ struct whisper_sequence { // the accumulated transcription in the current interation (used to truncate the tokens array) int result_len; - double sum_logprobs; + double sum_logprobs; // the sum of the log probabilities of the tokens + double avg_logprobs; // the average log probability of the tokens + double score; // likelihood rank score }; // TAGS: WHISPER_DECODER_INIT @@ -438,11 +440,12 @@ struct whisper_decoder { bool completed; // has the decoder completed the current segment? bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? - std::vector tokens; - - // new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; std::vector logits; std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls }; struct whisper_context { @@ -774,6 +777,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // TAGS: WHISPER_DECODER_INIT wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + wctx.decoders[0].probs.reserve (vocab.n_vocab); wctx.decoders[0].logits.reserve (vocab.n_vocab); wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } @@ -1959,99 +1963,6 @@ static bool whisper_decode( return true; } -// the most basic sampling scheme - select the top token -static whisper_token_data whisper_sample_best( - whisper_context & ctx, - const float * probs, - bool force_timestamp, - bool is_initial) { - whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, - }; - - const auto & vocab = ctx.vocab; - - const int n_logits = vocab.n_vocab; - - auto & probs_id = ctx.probs_id; - - probs_id.clear(); - for (int i = 0; i < n_logits; i++) { - probs_id.emplace_back(probs[i], i); - } - - { - double sum_ts = 0.0; - double max_ts = -1.0; - double max_tx = -1.0; - - for (int i = 0; i < vocab.token_beg; i++) { - max_tx = std::max(max_tx, probs_id[i].first); - } - - const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; - const auto i1 = is_initial ? 
vocab.token_beg + 101 : n_logits; - - // the initial timestamp cannot be larger than 100 - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial) { - for (int i = i0; i < n_logits; ++ i) { - probs_id[i].first = -INFINITY; - } - } - - for (int i = vocab.token_beg; i < i1; i++) { - sum_ts += probs_id[i].first; - if (probs_id[i].first > max_ts) { - max_ts = probs_id[i].first; - result.tid = probs_id[i].second; - } - } - - // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a - // timestamp token - if (sum_ts > max_tx || force_timestamp) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 - for (int i = 0; i < vocab.token_beg; i++) { - probs_id[i].first = -INFINITY; - } - } - - result.pt = max_ts/(sum_ts + 1e-10); - result.ptsum = sum_ts; - } - - // find the top K tokens - const int top_k = 4; - - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - int res = 0; - while ((probs_id[res].second == vocab.token_sot || - probs_id[res].second == vocab.token_solm || - probs_id[res].second == vocab.token_not) && - res < (int) probs_id.size() - 1) { - res++; - } - - result.id = probs_id[res].second; - result.p = probs_id[res].first; - - return result; -} - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2793,8 +2704,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, + /*.length_penalty =*/ -1.0f, /*.temperature_increment =*/ 0.2f, /*.compression_ratio_threshold =*/ 2.4f, @@ -2809,7 +2721,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_size =*/ -1, /*.patience =*/ -1.0f, - /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2853,8 +2764,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, + /*.length_penalty =*/ -1.0f, /*.temperature_increment =*/ 0.2f, /*.compression_ratio_threshold =*/ 2.4f, @@ -2869,7 +2781,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_size =*/ 5, /*.patience =*/ -1.0f, - /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2963,9 +2874,13 @@ static void whisper_process_logits( // extract the logits for the last token // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; auto & logits = decoder.logits; auto & logprobs = decoder.logprobs; { + probs.resize(n_logits); + memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float)); + logits.resize(n_logits); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); @@ -2995,6 
+2910,8 @@ static void whisper_process_logits( const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + if (last_was_timestamp) { if (penultimate_was_timestamp) { for (int i = vocab.token_beg; i < n_logits; ++ i) { @@ -3099,6 +3016,83 @@ static void whisper_process_logits( //return result; } +// select the most probable token +static whisper_token_data whisper_sample_best( + whisper_context & ctx, + whisper_decoder & decoder) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + for (int i = 0; i < n_logits; ++i) { + // never sample these: + if (i == vocab.token_sot || + i == vocab.token_solm || + i == vocab.token_not) { + continue; + } + + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + + return result; +} + +// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + WHISPER_ASSERT(sequence.result_len > 0); + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + + if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty) / 6.0, params.length_penalty); + } + + sequence.score = result/penalty; +} + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -3191,6 +3185,7 @@ int whisper_full( ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + ctx->decoders[i].probs.reserve (ctx->vocab.n_vocab); ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab); ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab); } @@ -3261,7 +3256,7 @@ int whisper_full( } // encode audio features starting at offset seek - if (whisper_encode(ctx, seek, params.n_threads) != 0) { + if (!whisper_encode(*ctx, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } @@ -3272,132 +3267,7 @@ int whisper_full( prompt_past.clear(); } - //prompt.clear(); - - //// if we have already generated some text, use it as a prompt to condition the next generation - //if (!prompt_past.empty()) { - // int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); - - // prompt = { whisper_token_prev(ctx) }; - // prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); - - // prompt_past.clear(); - // prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - //} - - //prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); 
- - //// print the prompt - ////printf("\n\n"); - ////for (int i = 0; i < prompt.size(); i++) { - //// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); - ////} - ////printf("\n\n"); - - //int n_past = 0; - //int seek_delta = 100*WHISPER_CHUNK_SIZE; - - //// the accumulated transcription in the current interation - //int result_len = 0; - //tokens_cur.clear(); - - //bool failed = false; // has the current segment failed to decode? - //bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? - - //double sum_logprobs = 0.0; - - //for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - // if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { - // fprintf(stderr, "%s: failed to decode\n", __func__); - // return -7; - // } - - // n_past += prompt.size(); - // prompt.clear(); - - // // sample the next token based on the selected decoding strategy + parameters - // // also, update the sliding window position based on the sampled timestamp tokens - // { - // const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); - - // // timestamp token - update sliding window - // if (token.id > whisper_token_beg(ctx)) { - // const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); - - // // do not allow to go back in time - // if (has_ts && seek_delta > seek_delta_new && result_len < i) { - // break; - // } - - // seek_delta = seek_delta_new; - // result_len = i + 1; - // has_ts = true; - // } - - // // add it to the context - // prompt.push_back(token.id); - // tokens_cur.push_back(token); - - // //{ - // // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - // // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); - // //} - - // // end of segment - // if (token.id == whisper_token_eot(ctx) || // end of text token - // (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - // (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - // ) { - // if (result_len == 0) { - // if (seek + seek_delta + 100 >= seek_end) { - // result_len = i + 1; - // } else { - // failed = true; - // break; - // } - // } - - // if (params.single_segment) { - // result_len = i + 1; - // seek_delta = 100*WHISPER_CHUNK_SIZE; - // } - - // break; - // } - - // // TESTS: if no tensors are loaded, it means we are running tests - // if (ctx->model.n_loaded == 0) { - // seek_delta = 100*WHISPER_CHUNK_SIZE; - // break; - // } - // } - - // // sometimes, the decoding can get stuck in a repetition loop - // // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy - // if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - // failed = true; - // break; - // } - //} - - //if (failed) { - // // when we fail to sample timestamp token, retry by clearing the past prompt - // // if it fails again, then we advance the window by 1 second - // if (!prompt_past.empty()) { - // prompt_past.clear(); - // } else { - // fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__); - // seek += 100; - // } - // continue; - //} - - //// shrink down to result_len - //tokens_cur.resize(result_len); - - //for (const auto & r : tokens_cur) { - // prompt_past.push_back(r.id); - //} + int best_decoder_id = 0; for (int it = 
0; it < (int) temperatures.size(); ++it) { const float t_cur = temperatures[it]; @@ -3429,12 +3299,14 @@ int whisper_full( } // TAGS: WHISPER_DECODER_INIT - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs = 0.0; + decoder.sequence.avg_logprobs = 0.0; + decoder.sequence.score = 0.0; decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; @@ -3467,65 +3339,148 @@ int whisper_full( //} //printf("\n\n"); - if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) { + if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } whisper_process_logits(*ctx, ctx->decoders[0], params); - for (int i = 1; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); decoder.n_past += prompt.size(); + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } } for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - bool completed = true; - - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - if (decoder.completed || decoder.failed) { - continue; - } + if (decoder.completed || decoder.failed) { + continue; + } + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { if (t_cur < 1e-6f) { - // select top token + decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder)); } else { } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + // TODO: .. + } break; + }; + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + break; } - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + { - // TODO: .. - } break; - }; + const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token[token.tid] : "[?]"; + printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + } - if (completed) { - break; + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + break; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + completed = true; + break; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + break; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + break; + } + } + + // check if all decoders have finished (i.e. completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } } - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; if (decoder.failed || decoder.completed) { continue; } - decoder.tokens.resize(1); - decoder.tokens[0] = decoder.sequence.tokens.back().id; + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); - if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) { + if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } @@ -3536,14 +3491,36 @@ int whisper_full( } } - // TODO: rank the resulting sequences and select the best one + // rank the resulting sequences and select the best one { + double best_score = -1e9; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed) { + continue; + } + + whisper_sequence_score(params, ctx->decoders[j].sequence); + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } } bool success = true; - // TODO: implement logprob threshold + compression threshold + // implement logprob threshold + // TODO: implement compression threshold { + auto & decoder = ctx->decoders[best_decoder_id]; + + if (decoder.sequence.avg_logprobs < params.logprob_threshold) { + success = false; + } } if (success) { @@ -3553,113 +3530,121 @@ int whisper_full( fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } - // TODO - const int seek_delta = 0; - whisper_sequence seq_best; + { + const auto & best_decoder = ctx->decoders[best_decoder_id]; - const auto & tokens_cur = seq_best.tokens; + const 
auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; - // TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens - //prompt_past.clear(); - //prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + const auto & tokens_cur = best_decoder.sequence.tokens; - // store the text from this iteration - if (!tokens_cur.empty()) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + // update prompt_past + prompt_past.clear(); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - std::string text; + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } - for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + // store the text from this iteration + if (!tokens_cur.empty()) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { - text += whisper_token_to_str(ctx, tokens_cur[i].id); - } - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); - if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; - - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } } - } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j <= i; j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } - int n_new = 1; + int n_new = 1; - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; } + i--; + t0 = t1; + i0 = i + 1; } - text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { - i++; - } - i--; - t0 = t1; - i0 = i + 1; } - } - if (!text.empty()) { - const auto t1 = seek + seek_delta; + if (!text.empty()) { + const auto t1 = seek + seek_delta; - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } } - } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j < (int) tokens_cur.size(); j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } - int n_new = 1; + int n_new = 1; - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } - } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } - } - seek += seek_delta; + // update audio window + seek += seek_delta; + } } return 0; @@ -3704,6 +3689,7 @@ int whisper_full_parallel( ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); } diff --git a/whisper.h b/whisper.h index a3529c8..906c7c0 100644 --- a/whisper.h +++ b/whisper.h @@ -74,6 +74,7 @@ extern "C" { whisper_token tid; // forced timestamp token id float p; // probability of the token + float plog; // log probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens @@ -280,6 +281,7 @@ extern "C" { float temperature; float max_initial_timestamp; + float length_penalty; // fallback parameters float temperature_increment; @@ -295,7 +297,6 @@ extern "C" { int beam_size; float patience; - float length_penalty; } beam_search; whisper_new_segment_callback new_segment_callback;
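
The rule implemented by the removed whisper_sample_best (and carried over into the new logits processing) is: if the summed probability of all timestamp tokens exceeds the highest probability of any text token, the text tokens are suppressed so that a timestamp token must be sampled. A minimal standalone sketch of that rule, assuming a normalized probability vector and a hypothetical token_beg index for the first timestamp token:

// Minimal sketch of the timestamp-vs-text rule from whisper_sample_best
// (assumes `probs` is a normalized probability vector and `token_beg` marks
//  the first timestamp token - both are hypothetical inputs here).
#include <cstdio>
#include <vector>
#include <algorithm>

static void apply_timestamp_rule(std::vector<float> & probs, int token_beg, bool force_timestamp) {
    double sum_ts = 0.0; // total probability mass of the timestamp tokens
    double max_tx = 0.0; // highest probability among the text tokens

    for (int i = 0; i < token_beg; ++i) {
        max_tx = std::max(max_tx, (double) probs[i]);
    }
    for (int i = token_beg; i < (int) probs.size(); ++i) {
        sum_ts += probs[i];
    }

    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
    // if timestamps are collectively more likely than any single text token,
    // zero out the text tokens so that only a timestamp can be sampled
    if (sum_ts > max_tx || force_timestamp) {
        for (int i = 0; i < token_beg; ++i) {
            probs[i] = 0.0f;
        }
    }
}

int main() {
    // toy vocabulary: 4 text tokens followed by 4 timestamp tokens
    std::vector<float> probs = { 0.20f, 0.10f, 0.05f, 0.05f, 0.25f, 0.20f, 0.10f, 0.05f };
    apply_timestamp_rule(probs, /*token_beg =*/ 4, /*force_timestamp =*/ false);
    for (size_t i = 0; i < probs.size(); ++i) {
        printf("token %zu: %.2f\n", i, probs[i]);
    }
    return 0;
}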
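
whisper_process_logits now keeps per-decoder probs, logits and logprobs buffers. A sketch of deriving the probability and log-probability arrays from raw logits with a numerically stable softmax; the actual routine also applies suppression rules (blank, non-speech, timestamp constraints) that are omitted here:

// Sketch: derive probs and logprobs from raw logits via a stable softmax.
// Illustrative only - the real whisper_process_logits mutates the logits
// before and after this step.
#include <cstdio>
#include <vector>
#include <cmath>
#include <algorithm>

static void logits_to_probs(const std::vector<float> & logits,
                            std::vector<float> & probs,
                            std::vector<float> & logprobs) {
    const int n = (int) logits.size();
    probs.resize(n);
    logprobs.resize(n);

    // subtract the max logit for numerical stability
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (int i = 0; i < n; ++i) {
        sum_exp += std::exp((double) (logits[i] - max_logit));
    }
    const float log_sum = (float) (max_logit + std::log(sum_exp));

    for (int i = 0; i < n; ++i) {
        logprobs[i] = logits[i] - log_sum;    // log-softmax
        probs[i]    = std::exp(logprobs[i]);  // softmax
    }
}

int main() {
    std::vector<float> probs, logprobs;
    logits_to_probs({ 1.0f, 2.0f, 3.0f }, probs, logprobs);
    for (size_t i = 0; i < probs.size(); ++i) {
        printf("p = %.4f, logp = %.4f\n", probs[i], logprobs[i]);
    }
    return 0;
}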
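
whisper_sequence_score and the ranking block select the best decoder by the sum of token log-probabilities divided by a length penalty (the token count when length_penalty <= 0, otherwise the Google-NMT-style ((5 + len)/6)^length_penalty), and the avg_logprobs check then decides whether to fall back to a higher temperature. A condensed, self-contained sketch with made-up numbers and simplified stand-in structs:

// Condensed sketch of sequence ranking + the fallback decision, using
// simplified stand-in structs instead of whisper_decoder/whisper_sequence.
#include <cstdio>
#include <cmath>
#include <vector>

struct toy_sequence {
    std::vector<double> token_logprobs;
    double sum_logprobs = 0.0;
    double avg_logprobs = 0.0;
    double score        = 0.0;
};

struct toy_decoder {
    bool failed = false;
    toy_sequence sequence;
};

// mirrors whisper_sequence_score
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
static void score_sequence(toy_sequence & seq, float length_penalty) {
    double result = 0.0;
    for (double lp : seq.token_logprobs) {
        result += lp;
    }

    seq.sum_logprobs = result;
    seq.avg_logprobs = result/seq.token_logprobs.size();

    double penalty = (double) seq.token_logprobs.size();
    if (length_penalty > 0.0f) {
        penalty = pow((5.0 + penalty)/6.0, length_penalty);
    }

    seq.score = result/penalty;
}

int main() {
    const float length_penalty    = -1.0f; // <= 0: plain length normalization
    const float logprob_threshold = -1.0f; // mirrors the OpenAI default

    std::vector<toy_decoder> decoders(3);
    decoders[0].sequence.token_logprobs = { -0.2, -0.3, -0.3 };
    decoders[1].sequence.token_logprobs = { -0.1, -0.1 };
    decoders[1].failed = true;             // failed decoders are skipped
    decoders[2].sequence.token_logprobs = { -0.1, -0.2, -0.2, -0.3 };

    int    best_decoder_id = 0;
    double best_score      = -1e9;

    for (int j = 0; j < (int) decoders.size(); ++j) {
        auto & decoder = decoders[j];
        if (decoder.failed) {
            continue;
        }
        score_sequence(decoder.sequence, length_penalty);
        if (best_score < decoder.sequence.score) {
            best_score      = decoder.sequence.score;
            best_decoder_id = j;
        }
    }

    const auto & best = decoders[best_decoder_id].sequence;
    const bool success = best.avg_logprobs >= logprob_threshold;

    printf("best decoder: %d, score: %.3f, avg logprob: %.3f -> %s\n",
            best_decoder_id, best.score, best.avg_logprobs,
            success ? "accept" : "retry at a higher temperature");
    return 0;
}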
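
The outer loop iterates over a list of temperatures and retries the whole segment at the next temperature when the best sequence fails the quality checks. A sketch of building such a ladder from temperature and temperature_increment; the upper bound of 1.0 is an assumption borrowed from the OpenAI reference, since the construction of the list is not shown in these hunks:

// Sketch: build the temperature fallback schedule from the decoding params.
// The 1.0 upper bound is assumed (OpenAI reference behaviour), not taken
// from this patch.
#include <cstdio>
#include <vector>

static std::vector<float> make_temperature_ladder(float temperature, float temperature_increment) {
    std::vector<float> temperatures;

    if (temperature_increment > 0.0f) {
        for (float t = temperature; t <= 1.0f + 1e-6f; t += temperature_increment) {
            temperatures.push_back(t);
        }
    } else {
        temperatures.push_back(temperature);
    }

    return temperatures;
}

int main() {
    // defaults from whisper_full_default_params: 0.0f start, 0.2f increment
    for (float t : make_temperature_ladder(0.0f, 0.2f)) {
        printf("%.1f ", t); // 0.0 0.2 0.4 0.6 0.8 1.0
    }
    printf("\n");
    return 0;
}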
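
When a sampled token is a timestamp token, the window advance is seek_delta = 2*(token.id - whisper_token_beg(ctx)), because timestamp tokens step in 20 ms increments while seek offsets count 10 ms frames, and to_timestamp formats such values (500 -> 00:05.000). A small sketch of both conversions; the token_beg value is hypothetical:

// Sketch: map a sampled timestamp token to an audio-window advance and to a
// printable time. The token_beg id below is hypothetical.
#include <cstdio>
#include <cstdint>
#include <string>

// to_timestamp-style formatter: t is in 10 ms units, so 500 -> "00:05.000"
static std::string format_timestamp(int64_t t) {
    int64_t msec = t*10;
    const int64_t min = msec/(1000*60); msec -= min*1000*60;
    const int64_t sec = msec/1000;      msec -= sec*1000;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
    return buf;
}

int main() {
    const int token_beg = 50363;            // id of the first timestamp token (hypothetical)
    const int sampled   = token_beg + 250;  // i.e. the "<|5.00|>" token

    // timestamp tokens step in 0.02 s; seek is counted in 0.01 s frames
    const int seek_delta = 2*(sampled - token_beg);

    printf("advance window by %d frames, segment end at %s\n",
            seek_delta, format_timestamp(seek_delta).c_str());
    return 0;
}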
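
On the caller side, temperature, max_initial_timestamp, length_penalty and the fallback thresholds now sit at the top level of whisper_full_params, with length_penalty removed from the beam_search block, and whisper_token_data gains a plog field. A usage sketch of setting these together with a new-segment callback; the accessor and field names follow whisper.h of this period and the values are illustrative, not recommended defaults:

// Usage sketch: configure the decoding/fallback parameters touched by this
// patch and attach a new-segment callback (API names assumed from whisper.h
// of this period; values are illustrative).
#include <cstdio>
#include "whisper.h"

static void on_new_segment(struct whisper_context * ctx, int n_new, void * /*user_data*/) {
    const int n_segments = whisper_full_n_segments(ctx);

    // only the last n_new segments are new in this call
    for (int i = n_segments - n_new; i < n_segments; ++i) {
        printf("[%lld --> %lld] %s\n",
                (long long) whisper_full_get_segment_t0(ctx, i),
                (long long) whisper_full_get_segment_t1(ctx, i),
                whisper_full_get_segment_text(ctx, i));
    }
}

int transcribe(struct whisper_context * ctx, const float * samples, int n_samples) {
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.temperature           = 0.0f;  // start deterministic
    params.temperature_increment = 0.2f;  // retry at 0.2, 0.4, ... on failure
    params.length_penalty        = -1.0f; // <= 0: rank sequences by average logprob
    params.logprob_threshold     = -1.0f; // below this, fall back to a higher temperature

    params.new_segment_callback           = on_new_segment;
    params.new_segment_callback_user_data = nullptr;

    return whisper_full(ctx, params, samples, n_samples);
}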