From 3d723d0b82b84cf998efb21b63a713ce73cfc3a9 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 12 Jan 2023 21:44:05 +0200
Subject: [PATCH] whisper : wip decoding parameters + strategies (part 2)

---
 whisper.cpp | 589 +++++++++++++++++++++++++++++++++-------------------
 1 file changed, 371 insertions(+), 218 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 6759cd2..4e11e67 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -418,14 +418,31 @@ struct whisper_model {
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
+
+    // the accumulated transcription in the current iteration (used to truncate the tokens array)
+    int result_len;
+
+    double sum_logprobs;
 };
 
+// TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
     whisper_kv_cache kv_self;
 
     whisper_sequence sequence;
 
-    std::vector<whisper_token> prompt;
+    int n_past;
+    int seek_delta;
+
+    bool failed;    // has the current segment failed to decode?
+    bool completed; // has the decoder completed the current segment?
+    bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
+
+    std::vector<whisper_token> tokens;
+
+    // new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+    std::vector<float> logits;
+    std::vector<float> logprobs;
 };
 
 struct whisper_context {
@@ -447,23 +464,17 @@ struct whisper_context {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
-    int selected_decoder_id = 0;
-
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
-    // decode output
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> probs;
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
 
-    // used to avoid allocations
-    std::vector<float> work_logits;
-    std::vector<float> work_logprobs;
-
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
 
     // [EXPERIMENTAL] token-level timestamps data
@@ -651,8 +662,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        wctx.selected_decoder_id = 0;
-
         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             return false;
@@ -760,13 +769,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
         wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-        wctx.work_logits.reserve (vocab.n_vocab);
-        wctx.work_logprobs.reserve(vocab.n_vocab);
-
         wctx.probs_id.reserve(n_vocab);
 
+        // TAGS: WHISPER_DECODER_INIT
         wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-        wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
+
+        wctx.decoders[0].logits.reserve  (vocab.n_vocab);
+        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
     size_t ctx_size = 0;
@@ -1586,18 +1595,16 @@ static bool whisper_encode(
 //   - n_past: number of past tokens to prefix the prompt with
 //
 static bool whisper_decode(
-        whisper_context & wctx,
-        const int n_threads,
-        const whisper_token * tokens,
-        const int n_tokens,
-        const int n_past) {
+        whisper_context & wctx,
+        whisper_decoder & decoder,
+        const int n_threads,
+        const whisper_token * tokens,
+        const int n_tokens,
+        const int n_past) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
-    WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
-
-    auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
+    auto & kv_self = decoder.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
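Aside: the struct changes above move the scratch logits/logprobs buffers from the shared context into each decoder, so several decoders can sample concurrently. The per-decoder `logprobs` buffer holds the log-softmax of the (filtered) logits; the patch fills it inside whisper_process_logits() further down. Below is a minimal standalone sketch of just that computation — the helper name `log_softmax_sketch` is hypothetical and does not exist in whisper.cpp:

#include <algorithm>
#include <cmath>
#include <vector>

// logprob[i] = logit[i] - log(sum_j exp(logit[j]))
// subtracting the max logit before exponentiating keeps exp() from overflowing
static void log_softmax_sketch(const std::vector<float> & logits, std::vector<float> & logprobs) {
    logprobs.resize(logits.size());

    const float max_logit = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (const float l : logits) {
        sum_exp += std::exp(l - max_logit);
    }

    const float log_sum = max_logit + (float) std::log(sum_exp);

    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = logits[i] - log_sum;
    }
}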
@@ -2506,7 +2513,10 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+    // TODO: add selected_decoder_id to context
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2516,28 +2526,6 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-// TODO: remove
-static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
-// TODO: remove
-static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
 int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
     const auto res = tokenize(ctx->vocab, text);
 
@@ -2899,15 +2887,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
 // forward declarations
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum);
+        struct whisper_context & ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum);
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
-    auto segment = ctx->result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+    auto segment = ctx.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -2916,34 +2904,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     for (int i = 0; i < (int) segment.tokens.size(); i++) {
         const auto & token = segment.tokens[i];
-        if (token.id >= whisper_token_eot(ctx)) {
+        if (token.id >= whisper_token_eot(&ctx)) {
             continue;
         }
 
-        const auto txt = whisper_token_to_str(ctx, token.id);
+        const auto txt = whisper_token_to_str(&ctx, token.id);
 
         const int cur = strlen(txt);
 
         if (acc + cur > max_len && i > 0) {
             // split here
-            ctx->result_all.back().text = std::move(text);
-            ctx->result_all.back().t1 = token.t0;
-            ctx->result_all.back().tokens.resize(i);
+            ctx.result_all.back().text = std::move(text);
+            ctx.result_all.back().t1 = token.t0;
+            ctx.result_all.back().tokens.resize(i);
 
-            ctx->result_all.push_back({});
-            ctx->result_all.back().t0 = token.t0;
-            ctx->result_all.back().t1 = segment.t1;
+            ctx.result_all.push_back({});
+            ctx.result_all.back().t0 = token.t0;
+            ctx.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx->result_all.back().tokens.insert(
-                ctx->result_all.back().tokens.end(),
+            ctx.result_all.back().tokens.insert(
+                ctx.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx->result_all.back();
+            segment = ctx.result_all.back();
             i = -1;
 
             res++;
@@ -2953,33 +2941,33 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
         }
     }
 
-    ctx->result_all.back().text = std::move(text);
+    ctx.result_all.back().text = std::move(text);
 
     return res;
 }
 
-static struct whisper_token_data whisper_sample_next_token(
-        struct whisper_context * ctx,
-        struct whisper_full_params params,
-        double & sum_logprobs,
-        const std::vector<whisper_token> & prompt,
-        const std::vector<whisper_token_data> & tokens_cur) {
-    struct whisper_token_data result = {};
-
-    const auto & vocab = ctx->vocab;
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs
+static void whisper_process_logits(
+        struct whisper_context & ctx,
+        struct whisper_decoder & decoder,
+        struct whisper_full_params params) {
+    const auto & vocab = ctx.vocab;
+    const auto & tokens_cur = decoder.sequence.tokens;
 
     const bool is_initial = tokens_cur.size() == 0;
     const int  n_logits   = vocab.id_to_token.size();
 
-    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
+    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
 
     // extract the logits for the last token
-    // we will be mutating and therefore we don't want to use the ctx->logits buffer directly
-    auto & logits   = ctx->work_logits;
-    auto & logprobs = ctx->work_logprobs;
+    // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
+    auto & logits   = decoder.logits;
+    auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
 
         // will be populated a bit later
         logprobs.resize(n_logits);
@@ -3023,7 +3011,7 @@ static struct whisper_token_data whisper_sample_next_token(
     // the initial timestamp cannot be larger than max_initial_timestamp
     // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
     if (is_initial && params.max_initial_timestamp > 0.0f) {
-        const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
+        const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
         const int   tid0      = std::round(params.max_initial_timestamp/precision);
 
         for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
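Aside: a worked example of the max_initial_timestamp rule in the hunk above, assuming the usual Whisper constants (WHISPER_CHUNK_SIZE == 30 seconds, n_audio_ctx == 1500). The helper name is hypothetical; the real code runs inline over decoder.logits:

#include <cmath>
#include <limits>

// precision = 30.0/1500 = 0.02 s -> one timestamp token per 20 ms
// tid0      = round(1.0/0.02)    = 50 for max_initial_timestamp = 1.0 s
// so every timestamp token later than token_beg + 50 is masked out before sampling:
static void mask_initial_timestamps_sketch(float * logits, int n_logits, int token_beg) {
    const float precision = 30.0f/1500.0f;
    const int   tid0      = (int) std::round(1.0f/precision);

    for (int i = token_beg + tid0 + 1; i < n_logits; ++i) {
        logits[i] = -std::numeric_limits<float>::infinity();
    }
}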
@@ -3091,24 +3079,24 @@ static struct whisper_token_data whisper_sample_next_token(
         printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
 #endif
 
-    switch (params.strategy) {
-        case WHISPER_SAMPLING_GREEDY:
-            {
-                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
-                // TODO: implement
-                result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
-            } break;
-        case WHISPER_SAMPLING_BEAM_SEARCH:
-            {
-                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
-                // TODO: implement
-            } break;
-    }
+    //switch (params.strategy) {
+    //    case WHISPER_SAMPLING_GREEDY:
+    //        {
+    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
+    //            // TODO: implement
+    //            result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+    //        } break;
+    //    case WHISPER_SAMPLING_BEAM_SEARCH:
+    //        {
+    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
+    //            // TODO: implement
+    //        } break;
+    //}
 
-    sum_logprobs += logprobs[result.id];
-    printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
+    //sum_logprobs += logprobs[result.id];
+    //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
 
-    return result;
+    //return result;
 }
 
 int whisper_full(
@@ -3191,6 +3179,7 @@
     };
 
     for (int i = 1; i < n_decoders; i++) {
+        // TAGS: WHISPER_DECODER_INIT
         if (ctx->decoders[i].kv_self.ctx == nullptr) {
             ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
             if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
@@ -3201,7 +3190,9 @@
             fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
 
             ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
-            ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity());
+
+            ctx->decoders[i].logits.reserve  (ctx->vocab.n_vocab);
+            ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
         }
     }
 
@@ -3242,14 +3233,12 @@
     int progress_prev = 0;
     int progress_step = 5;
 
-    std::vector<whisper_token_data> tokens_cur;
-    tokens_cur.reserve(whisper_n_text_ctx(ctx));
+    int seek = seek_start;
 
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
     // main loop
-    int seek = seek_start;
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
         while (progress_cur >= progress_prev + progress_step) {
@@ -3283,133 +3272,297 @@
             prompt_past.clear();
         }
 
-        prompt.clear();
+        //prompt.clear();
 
-        // if we have already generated some text, use it as a prompt to condition the next generation
-        if (!prompt_past.empty()) {
-            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+        //// if we have already generated some text, use it as a prompt to condition the next generation
+        //if (!prompt_past.empty()) {
+        //    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
-            prompt = { whisper_token_prev(ctx) };
-            prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+        //    prompt = { whisper_token_prev(ctx) };
+        //    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
 
-            prompt_past.clear();
-            prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
-        }
+        //    prompt_past.clear();
+        //    prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+        //}
+
+        //prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+        //// print the prompt
+        ////printf("\n\n");
+        ////for (int i = 0; i < prompt.size(); i++) {
+        ////    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+        ////}
+        ////printf("\n\n");
+
+        //int n_past = 0;
+        //int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
+        //// the accumulated transcription in the current iteration
+        //int result_len = 0;
+        //tokens_cur.clear();
+
+        //bool failed = false; // has the current segment failed to decode?
+        //bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+
+        //double sum_logprobs = 0.0;
+
+        //for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+        //    if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
+        //        fprintf(stderr, "%s: failed to decode\n", __func__);
+        //        return -7;
+        //    }
+
+        //    n_past += prompt.size();
+        //    prompt.clear();
+
+        //    // sample the next token based on the selected decoding strategy + parameters
+        //    // also, update the sliding window position based on the sampled timestamp tokens
+        //    {
+        //        const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
+
+        //        // timestamp token - update sliding window
+        //        if (token.id > whisper_token_beg(ctx)) {
+        //            const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+        //            // do not allow to go back in time
+        //            if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+        //                break;
+        //            }
+
+        //            seek_delta = seek_delta_new;
+        //            result_len = i + 1;
+        //            has_ts = true;
+        //        }
+
+        //        // add it to the context
+        //        prompt.push_back(token.id);
+        //        tokens_cur.push_back(token);
+
+        //        //{
+        //        //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+        //        //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+        //        //}
+
+        //        // end of segment
+        //        if (token.id == whisper_token_eot(ctx) ||                // end of text token
+        //            (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+        //            (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
+        //            ) {
+        //            if (result_len == 0) {
+        //                if (seek + seek_delta + 100 >= seek_end) {
+        //                    result_len = i + 1;
+        //                } else {
+        //                    failed = true;
+        //                    break;
+        //                }
+        //            }
+
+        //            if (params.single_segment) {
+        //                result_len = i + 1;
+        //                seek_delta = 100*WHISPER_CHUNK_SIZE;
+        //            }
+
+        //            break;
+        //        }
+
+        //        // TESTS: if no tensors are loaded, it means we are running tests
+        //        if (ctx->model.n_loaded == 0) {
+        //            seek_delta = 100*WHISPER_CHUNK_SIZE;
+        //            break;
+        //        }
+        //    }
+
+        //    // sometimes, the decoding can get stuck in a repetition loop
+        //    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
+        //    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+        //        failed = true;
+        //        break;
+        //    }
+        //}
+
+        //if (failed) {
+        //    // when we fail to sample timestamp token, retry by clearing the past prompt
+        //    // if it fails again, then we advance the window by 1 second
+        //    if (!prompt_past.empty()) {
+        //        prompt_past.clear();
+        //    } else {
+        //        fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
+        //        seek += 100;
+        //    }
+        //    continue;
+        //}
 
-        prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+        //// shrink down to result_len
+        //tokens_cur.resize(result_len);
 
-        // print the prompt
-        //printf("\n\n");
-        //for (int i = 0; i < prompt.size(); i++) {
-        //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+        //for (const auto & r : tokens_cur) {
+        //    prompt_past.push_back(r.id);
         //}
-        //printf("\n\n");
 
-        int n_past = 0;
-        int seek_delta = 100*WHISPER_CHUNK_SIZE;
+        for (int it = 0; it < (int) temperatures.size(); ++it) {
+            const float t_cur = temperatures[it];
 
-        // the accumulated transcription in the current iteration
-        int result_len = 0;
-        tokens_cur.clear();
+            int n_decoders_cur = 1;
+            switch (params.strategy) {
+                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        }
+                    } break;
-        bool failed = false; // has the current segment failed to decode?
-        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        } else {
+                            n_decoders_cur = params.beam_search.beam_size;
+                        }
+                    } break;
+            };
 
-        double sum_logprobs = 0.0;
+            fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
 
-        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
-            if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
-                fprintf(stderr, "%s: failed to decode\n", __func__);
-                return -7;
+            if (t_cur > 0.5) {
+                prompt_past.clear();
+
+                fprintf(stderr, "%s: clearing prompt_past\n", __func__);
             }
 
-            n_past += prompt.size();
-            prompt.clear();
+            // TAGS: WHISPER_DECODER_INIT
+            for (int i = 0; i < n_decoders_cur; ++i) {
+                auto & decoder = ctx->decoders[i];
 
-            // sample the next token based on the selected decoding strategy + parameters
-            // also, update the sliding window position based on the sampled timestamp tokens
-            {
-                const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
+                decoder.sequence.tokens.clear();
+                decoder.sequence.result_len   = 0;
+                decoder.sequence.sum_logprobs = 0.0;
 
-                // timestamp token - update sliding window
-                if (token.id > whisper_token_beg(ctx)) {
-                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+                decoder.n_past     = 0;
+                decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
 
-                    // do not allow to go back in time
-                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
-                        break;
-                    }
+                decoder.failed    = false;
+                decoder.completed = false;
+                decoder.has_ts    = false;
+            }
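Aside: the `temperatures` vector iterated above is built outside the hunks shown here, so its exact construction is an assumption. A plausible sketch following OpenAI's fallback rule (start at the base temperature and step up to 1.0) — the function name and parameter names are illustrative only:

#include <vector>

// e.g. make_temperature_schedule_sketch(0.0f, 0.2f, 1.0f) -> {0.0, 0.2, 0.4, 0.6, 0.8, 1.0}
static std::vector<float> make_temperature_schedule_sketch(float t0, float t_inc, float t_max) {
    std::vector<float> temperatures;

    // if t_inc is not positive, take a step larger than t_max so the loop runs exactly once
    const float step = t_inc > 0.0f ? t_inc : t_max + 1.0f;

    for (float t = t0; t <= t_max + 1e-6f; t += step) {
        temperatures.push_back(t);
    }

    return temperatures;
}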
-                    seek_delta = seek_delta_new;
-                    result_len = i + 1;
-                    has_ts = true;
+            // init prompt and kv cache for the current iteration
+            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+            {
+                prompt.clear();
+
+                // if we have already generated some text, use it as a prompt to condition the next generation
+                if (!prompt_past.empty()) {
+                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+
+                    prompt = { whisper_token_prev(ctx) };
+                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
                 }
 
-                // add it to the context
-                prompt.push_back(token.id);
-                tokens_cur.push_back(token);
+                // init new transcription with sot, language (opt) and task tokens
+                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-                //{
-                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
                 //}
+                // print the prompt
+                //printf("\n\n");
+                //for (int i = 0; i < prompt.size(); i++) {
+                //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
                 //}
+                //printf("\n\n");
 
-                // end of segment
-                if (token.id == whisper_token_eot(ctx) ||                // end of text token
-                    (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
-                    (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
-                    ) {
-                    if (result_len == 0) {
-                        if (seek + seek_delta + 100 >= seek_end) {
-                            result_len = i + 1;
-                        } else {
-                            failed = true;
-                            break;
-                        }
-                    }
+                if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) {
+                    fprintf(stderr, "%s: failed to decode\n", __func__);
+                    return -7;
+                }
 
-                    if (params.single_segment) {
-                        result_len = i + 1;
-                        seek_delta = 100*WHISPER_CHUNK_SIZE;
-                    }
+                whisper_process_logits(*ctx, ctx->decoders[0], params);
 
-                    break;
+                for (int i = 1; i < n_decoders_cur; ++i) {
+                    auto & decoder = ctx->decoders[i];
+
+                    memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
+
+                    decoder.n_past += prompt.size();
+
+                    memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                    memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                 }
+            }
+
+            for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+                bool completed = true;
+
+                switch (params.strategy) {
+                    case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                        {
+                            for (int i = 0; i < n_decoders_cur; ++i) {
+                                auto & decoder = ctx->decoders[i];
+
+                                if (decoder.completed || decoder.failed) {
+                                    continue;
+                                }
+
+                                if (t_cur < 1e-6f) {
+                                    // select top token
+                                } else {
+                                }
+                            }
+                        } break;
+                    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                        {
+                            // TODO: ..
+                        } break;
+                };
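Aside: the GREEDY branches above are still empty in this WIP commit. A hypothetical sketch of what they could do — argmax when t_cur is effectively zero, otherwise sample from the distribution obtained by scaling the logits with 1/t_cur. Nothing below is in the patch; the function name is illustrative:

#include <algorithm>
#include <cmath>
#include <iterator>
#include <random>
#include <vector>

static int sample_token_sketch(const std::vector<float> & logits, float t_cur, std::mt19937 & rng) {
    if (t_cur < 1e-6f) {
        // select top token (greedy)
        return (int) std::distance(logits.begin(), std::max_element(logits.begin(), logits.end()));
    }

    // temperature sampling: p_i is proportional to exp(logit_i / t_cur)
    const double max_l = *std::max_element(logits.begin(), logits.end());

    std::vector<double> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp((logits[i] - max_l)/t_cur);
    }

    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(rng);
}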
 
-                // TESTS: if no tensors are loaded, it means we are running tests
-                if (ctx->model.n_loaded == 0) {
-                    seek_delta = 100*WHISPER_CHUNK_SIZE;
+                if (completed) {
                     break;
                 }
+
+                for (int i = 0; i < n_decoders_cur; ++i) {
+                    auto & decoder = ctx->decoders[i];
+
+                    if (decoder.failed || decoder.completed) {
+                        continue;
+                    }
+
+                    decoder.tokens.resize(1);
+                    decoder.tokens[0] = decoder.sequence.tokens.back().id;
+
+                    if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) {
+                        fprintf(stderr, "%s: failed to decode\n", __func__);
+                        return -7;
+                    }
+
+                    whisper_process_logits(*ctx, decoder, params);
+
+                    ++decoder.n_past;
+                }
             }
 
-            // sometimes, the decoding can get stuck in a repetition loop
-            // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
-            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
-                failed = true;
-                break;
+            // TODO: rank the resulting sequences and select the best one
+            {
             }
-        }
 
-        if (failed) {
-            // when we fail to sample timestamp token, retry by clearing the past prompt
-            // if it fails again, then we advance the window by 1 second
-            if (!prompt_past.empty()) {
-                prompt_past.clear();
-            } else {
-                fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
-                seek += 100;
+            bool success = true;
+
+            // TODO: implement logprob threshold + compression threshold
+            {
             }
-            continue;
-        }
 
-        // shrink down to result_len
-        tokens_cur.resize(result_len);
+            if (success) {
+                break;
+            }
 
-        for (const auto & r : tokens_cur) {
-            prompt_past.push_back(r.id);
+            fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
         }
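Aside: a hypothetical sketch for the "rank the resulting sequences" TODO above, using the whisper_decoder/whisper_sequence types introduced by this patch. OpenAI's reference decoder picks the candidate with the highest length-normalized sum of logprobs (and additionally rejects candidates on compression-ratio/logprob thresholds, which the second TODO refers to); a minimal version of the ranking could look like this — the function is not part of the patch:

#include <algorithm>
#include <limits>

// returns the index of the best non-failed decoder, or -1 if all failed
static int pick_best_decoder_sketch(const whisper_decoder * decoders, int n_decoders) {
    int    best_id    = -1;
    double best_score = -std::numeric_limits<double>::infinity();

    for (int i = 0; i < n_decoders; ++i) {
        if (decoders[i].failed) {
            continue;
        }

        const auto & seq = decoders[i].sequence;

        // average logprob per generated token (length normalization)
        const double avg = seq.sum_logprobs/(double) std::max<size_t>(1, seq.tokens.size());

        if (avg > best_score) {
            best_score = avg;
            best_id    = i;
        }
    }

    return best_id;
}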
 
+        // TODO
+        const int seek_delta = 0;
+        whisper_sequence seq_best;
+
+        const auto & tokens_cur = seq_best.tokens;
+
+        // TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens
+        //prompt_past.clear();
+        //prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+
         // store the text from this iteration
         if (!tokens_cur.empty()) {
             int i0 = 0;
@@ -3450,10 +3603,10 @@
 
                         if (params.token_timestamps) {
                             whisper_exp_compute_token_level_timestamps(
-                                    ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                    *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                             if (params.max_len > 0) {
-                                n_new = whisper_wrap_segment(ctx, params.max_len);
+                                n_new = whisper_wrap_segment(*ctx, params.max_len);
                             }
                         }
                         if (params.new_segment_callback) {
@@ -3494,10 +3647,10 @@
 
                     if (params.token_timestamps) {
                         whisper_exp_compute_token_level_timestamps(
-                                ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
-                            n_new = whisper_wrap_segment(ctx, params.max_len);
+                            n_new = whisper_wrap_segment(*ctx, params.max_len);
                         }
                     }
                     if (params.new_segment_callback) {
@@ -3535,9 +3688,6 @@ int whisper_full_parallel(
         ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
         ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
 
-        ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab);
-        ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab);
-
         ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
 
         if (!kv_cache_reinit(ctx_p.kv_cross)) {
             fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
             return false;
         }
 
+        // TAGS: WHISPER_DECODER_INIT
         for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
             if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
                 fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
@@ -3552,7 +3703,9 @@
             }
 
             ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-            ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx);
+
+            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
         }
     }
 
@@ -3747,14 +3900,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 }
 
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum) {
-    auto & segment = ctx->result_all[i_segment];
+        struct whisper_context & ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+    auto & segment = ctx.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx->energy.size();
+    const int n_samples = ctx.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -3777,28 +3930,28 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx->t_beg;
-    auto & t_last   = ctx->t_last;
-    auto & tid_last = ctx->tid_last;
+    auto & t_beg    = ctx.t_beg;
+    auto & t_last   = ctx.t_last;
+    auto & tid_last = ctx.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
 
         if (j == 0) {
-            if (token.id == whisper_token_beg(ctx)) {
+            if (token.id == whisper_token_beg(&ctx)) {
                 tokens[j    ].t0 = t0;
                 tokens[j    ].t1 = t0;
                 tokens[j + 1].t0 = t0;
 
                 t_beg    = t0;
                 t_last   = t0;
-                tid_last = whisper_token_beg(ctx);
+                tid_last = whisper_token_beg(&ctx);
             } else {
                 tokens[j    ].t0 = t_last;
             }
         }
 
-        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
 
         tokens[j].id    = token.id;
         tokens[j].tid   = token.tid;
         tokens[j].p     = token.p;
         tokens[j].pt    = token.pt;
         tokens[j].ptsum = token.ptsum;
 
-        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+        tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
 
         if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
             if (j > 0) {
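Aside: in the hunk above, segment and token times (t0, t1, tt) are stored in ticks of 10 ms, while consecutive timestamp tokens are 20 ms apart — hence the factor 2 in `tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx))`. A hypothetical helper (not in whisper.cpp) making the unit conversion explicit:

#include <cstdint>

// convert a timestamp token id to 10 ms ticks relative to the window start
// e.g. tid - token_beg == 150  ->  300 ticks  ->  3.00 s
static int64_t timestamp_token_to_ticks_sketch(int tid, int token_beg) {
    return 2*(int64_t) (tid - token_beg);
}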
@@ -3885,7 +4038,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int hw = WHISPER_SAMPLE_RATE/8;
 
     for (int j = 0; j < n; j++) {
-        if (tokens[j].id >= whisper_token_eot(ctx)) {
+        if (tokens[j].id >= whisper_token_eot(&ctx)) {
             continue;
         }
@@ -3900,15 +4053,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;
 
             for (int k = ss0; k < ss1; k++) {
-                sum += ctx->energy[k];
+                sum += ctx.energy[k];
             }
 
             const float thold = 0.5*sum/ns;
 
             {
                 int k = s0;
-                if (ctx->energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx->energy[k] > thold) {
+                if (ctx.energy[k] > thold && j > 0) {
+                    while (k > 0 && ctx.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -3918,17 +4071,17 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx->energy[k] < thold && k < s1) {
+                    while (ctx.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
                 }
             }
 
             {
                 int k = s1;
-                if (ctx->energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
+                if (ctx.energy[k] > thold) {
+                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -3939,7 +4092,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s1 = k;
                     }
                 } else {
-                    while (ctx->energy[k] < thold && k > s0) {
+                    while (ctx.energy[k] < thold && k > s0) {
                         k--;
                     }
                     s1 = k;
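Aside: the energy-based boundary adjustment above walks `ctx.energy` against a threshold to snap token boundaries to voice activity. `get_signal_energy()` itself is only forward-declared in the hunks shown; from its signature and call sites it is a moving average of the signal magnitude over a window of 2*n_samples_per_half_window + 1 samples. A sketch consistent with that usage (treat the details as assumptions):

#include <cmath>
#include <vector>

static std::vector<float> get_signal_energy_sketch(const float * signal, int n_samples, int n_samples_per_half_window) {
    const int hw = n_samples_per_half_window;

    std::vector<float> result(n_samples);

    for (int i = 0; i < n_samples; i++) {
        float sum = 0.0f;

        // average |signal| over the window centered at i, clamped to the signal bounds
        for (int j = -hw; j <= hw; j++) {
            if (i + j >= 0 && i + j < n_samples) {
                sum += std::fabs(signal[i + j]);
            }
        }

        result[i] = sum/(2*hw + 1);
    }

    return result;
}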