whisper : wip decoding parameters + strategies (part 2)

pull/291/head
Georgi Gerganov 3 years ago
parent 9551d7fabd
commit 3d723d0b82

@@ -418,14 +418,31 @@ struct whisper_model {
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
+    // the accumulated transcription in the current iteration (used to truncate the tokens array)
+    int result_len;
+
+    double sum_logprobs;
 };
 
+// TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
     whisper_kv_cache kv_self;
 
     whisper_sequence sequence;
 
-    std::vector<whisper_token> prompt;
+    int n_past;
+    int seek_delta;
+
+    bool failed;    // has the current segment failed to decode?
+    bool completed; // has the decoder completed the current segment?
+    bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
+
+    std::vector<whisper_token> tokens;
+
+    // new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+    std::vector<float> logits;
+    std::vector<float> logprobs;
 };
 
 struct whisper_context {
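Note: the new `sum_logprobs` field accumulates the log-probability of every token sampled into the sequence; together with `result_len` it is what the fallback logic later in this commit needs in order to rank candidate sequences (the ranking itself is still a TODO below). A minimal sketch of the implied ranking criterion, as a hypothetical helper that is not part of this commit:

    // hypothetical: mean log-probability of a sequence -- higher is better
    static double whisper_sequence_score(const whisper_sequence & seq) {
        if (seq.tokens.empty()) {
            return -1e9; // an empty sequence should lose to any non-empty one
        }
        return seq.sum_logprobs / seq.tokens.size();
    }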
@@ -447,23 +464,17 @@ struct whisper_context {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
-    int selected_decoder_id = 0;
-
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
-    // decode output
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> probs;
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // used to avoid allocations
-    std::vector<float> work_logits;
-    std::vector<float> work_logprobs;
-
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
 
     // [EXPERIMENTAL] token-level timestamps data
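The updated comment pins down the buffer layout: `logits` and `probs` hold one `n_vocab`-wide row per decoded token, flattened into a single vector. A one-line sketch of how row `i` would be addressed (names from this diff, `i` illustrative):

    // logits of token i in the flattened [n_tokens][n_vocab] buffer
    const float * row = wctx.logits.data() + i*wctx.vocab.n_vocab;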
@@ -651,8 +662,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        wctx.selected_decoder_id = 0;
-
         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             return false;
@@ -760,13 +769,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
         wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-        wctx.work_logits.reserve (vocab.n_vocab);
-        wctx.work_logprobs.reserve(vocab.n_vocab);
-
         wctx.probs_id.reserve(n_vocab);
 
+        // TAGS: WHISPER_DECODER_INIT
         wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-        wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
+
+        wctx.decoders[0].logits.reserve (vocab.n_vocab);
+        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
     size_t ctx_size = 0;
@@ -1586,18 +1595,16 @@ static bool whisper_encode(
 //   - n_past: number of past tokens to prefix the prompt with
 //
 static bool whisper_decode(
         whisper_context & wctx,
+        whisper_decoder & decoder,
         const int n_threads,
         const whisper_token * tokens,
         const int n_tokens,
         const int n_past) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
-    WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
-
-    auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
+    auto & kv_self = decoder.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
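Passing the `whisper_decoder` explicitly is what lets each decoder advance its own KV cache position. A hedged usage sketch, following the internal bool-returning signature above (illustrative only, not a call site from this commit):

    // decode one sampled token on decoder i, at that decoder's cache position
    auto & decoder = wctx.decoders[i];
    if (!whisper_decode(wctx, decoder, n_threads, &token_id, 1, decoder.n_past)) {
        // handle failure (the whisper_full path bails out with -7)
    }
    decoder.n_past += 1;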
@@ -2506,7 +2513,10 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+    // TODO: add selected_decoder_id to context
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2516,28 +2526,6 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-// TODO: remove
-static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
-// TODO: remove
-static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
 int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
     const auto res = tokenize(ctx->vocab, text);
@@ -2899,15 +2887,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 // forward declarations
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
+        struct whisper_context & ctx,
         int   i_segment,
         float thold_pt,
         float thold_ptsum);
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
-    auto segment = ctx->result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+    auto segment = ctx.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -2916,34 +2904,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     for (int i = 0; i < (int) segment.tokens.size(); i++) {
         const auto & token = segment.tokens[i];
-        if (token.id >= whisper_token_eot(ctx)) {
+        if (token.id >= whisper_token_eot(&ctx)) {
             continue;
         }
 
-        const auto txt = whisper_token_to_str(ctx, token.id);
+        const auto txt = whisper_token_to_str(&ctx, token.id);
 
         const int cur = strlen(txt);
 
         if (acc + cur > max_len && i > 0) {
             // split here
-            ctx->result_all.back().text = std::move(text);
-            ctx->result_all.back().t1 = token.t0;
-            ctx->result_all.back().tokens.resize(i);
+            ctx.result_all.back().text = std::move(text);
+            ctx.result_all.back().t1 = token.t0;
+            ctx.result_all.back().tokens.resize(i);
 
-            ctx->result_all.push_back({});
-            ctx->result_all.back().t0 = token.t0;
-            ctx->result_all.back().t1 = segment.t1;
+            ctx.result_all.push_back({});
+            ctx.result_all.back().t0 = token.t0;
+            ctx.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx->result_all.back().tokens.insert(
-                ctx->result_all.back().tokens.end(),
+            ctx.result_all.back().tokens.insert(
+                ctx.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx->result_all.back();
+            segment = ctx.result_all.back();
             i = -1;
 
             res++;
@@ -2953,33 +2941,33 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
         }
     }
 
-    ctx->result_all.back().text = std::move(text);
+    ctx.result_all.back().text = std::move(text);
 
     return res;
 }
 
-static struct whisper_token_data whisper_sample_next_token(
-        struct whisper_context * ctx,
-        struct whisper_full_params params,
-        double & sum_logprobs,
-        const std::vector<whisper_token> & prompt,
-        const std::vector<whisper_token_data> & tokens_cur) {
-    struct whisper_token_data result = {};
-
-    const auto & vocab = ctx->vocab;
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs
+static void whisper_process_logits(
+        struct whisper_context & ctx,
+        struct whisper_decoder & decoder,
+        struct whisper_full_params params) {
+    const auto & vocab = ctx.vocab;
+
+    const auto & tokens_cur = decoder.sequence.tokens;
 
     const bool is_initial = tokens_cur.size() == 0;
     const int  n_logits   = vocab.id_to_token.size();
 
-    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
+    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
 
     // extract the logits for the last token
-    // we will be mutating and therefore we don't want to use the ctx->logits buffer directly
-    auto & logits   = ctx->work_logits;
-    auto & logprobs = ctx->work_logprobs;
+    // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
+    auto & logits   = decoder.logits;
+    auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
 
         // will be populated a bit later
         logprobs.resize(n_logits);
@@ -3023,7 +3011,7 @@ static struct whisper_token_data whisper_sample_next_token(
         // the initial timestamp cannot be larger than max_initial_timestamp
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
         if (is_initial && params.max_initial_timestamp > 0.0f) {
-            const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
+            const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
             const int   tid0      = std::round(params.max_initial_timestamp/precision);
 
             for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
@@ -3091,24 +3079,24 @@ static struct whisper_token_data whisper_sample_next_token(
     printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
 #endif
 
-    switch (params.strategy) {
-        case WHISPER_SAMPLING_GREEDY:
-            {
-                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
-                // TODO: implement
-                result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
-            } break;
-        case WHISPER_SAMPLING_BEAM_SEARCH:
-            {
-                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
-                // TODO: implement
-            } break;
-    }
-
-    sum_logprobs += logprobs[result.id];
-
-    printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
-
-    return result;
+    //switch (params.strategy) {
+    //    case WHISPER_SAMPLING_GREEDY:
+    //        {
+    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
+    //            // TODO: implement
+    //            result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+    //        } break;
+    //    case WHISPER_SAMPLING_BEAM_SEARCH:
+    //        {
+    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
+    //            // TODO: implement
+    //        } break;
+    //}
+
+    //sum_logprobs += logprobs[result.id];
+
+    //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
+
+    //return result;
 }
 
 int whisper_full(
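The "computes logprobs" step of whisper_process_logits falls outside the hunks shown here. Filling `logprobs` from the (already filtered) `logits` is a standard numerically stable log-softmax; a sketch under that assumption, not the literal body from this commit:

    // logprobs[i] = logits[i] - log(sum_j exp(logits[j]))
    // subtracting the max first keeps the exponentials from overflowing
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (int i = 0; i < n_logits; ++i) {
        sum_exp += exp(logits[i] - max_logit);
    }
    const float log_sum = max_logit + (float) log(sum_exp);
    for (int i = 0; i < n_logits; ++i) {
        logprobs[i] = logits[i] - log_sum;
    }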
@@ -3191,6 +3179,7 @@ int whisper_full(
     };
 
     for (int i = 1; i < n_decoders; i++) {
+        // TAGS: WHISPER_DECODER_INIT
         if (ctx->decoders[i].kv_self.ctx == nullptr) {
             ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
             if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
@@ -3201,7 +3190,9 @@ int whisper_full(
             fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
 
             ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
-            ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity());
+
+            ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab);
+            ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
         }
     }
@@ -3242,14 +3233,12 @@ int whisper_full(
     int progress_prev = 0;
     int progress_step = 5;
 
-    std::vector<whisper_token_data> tokens_cur;
-    tokens_cur.reserve(whisper_n_text_ctx(ctx));
+    int seek = seek_start;
 
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
     // main loop
-    int seek = seek_start;
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
         while (progress_cur >= progress_prev + progress_step) {
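The `temperatures` vector that drives the fallback loop in the next hunk is defined outside the lines shown here. The reference OpenAI implementation falls back over temperatures 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, so a definition along these lines is assumed:

    // assumed temperature fallback ladder, mirroring openai/whisper
    const std::vector<float> temperatures = { 0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };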
@@ -3283,133 +3272,297 @@ int whisper_full(
             prompt_past.clear();
         }
 
-        prompt.clear();
-
-        // if we have already generated some text, use it as a prompt to condition the next generation
-        if (!prompt_past.empty()) {
-            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
-
-            prompt = { whisper_token_prev(ctx) };
-            prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
-
-            prompt_past.clear();
-            prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
-        }
-
-        prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
-
-        // print the prompt
-        //printf("\n\n");
-        //for (int i = 0; i < prompt.size(); i++) {
-        //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
-        //}
-        //printf("\n\n");
-
-        int n_past = 0;
-        int seek_delta = 100*WHISPER_CHUNK_SIZE;
-
-        // the accumulated transcription in the current iteration
-        int result_len = 0;
-
-        tokens_cur.clear();
-
-        bool failed = false; // has the current segment failed to decode?
-        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
-
-        double sum_logprobs = 0.0;
-
-        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
-            if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
-                fprintf(stderr, "%s: failed to decode\n", __func__);
-                return -7;
-            }
-
-            n_past += prompt.size();
-            prompt.clear();
-
-            // sample the next token based on the selected decoding strategy + parameters
-            // also, update the sliding window position based on the sampled timestamp tokens
-            {
-                const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
-
-                // timestamp token - update sliding window
-                if (token.id > whisper_token_beg(ctx)) {
-                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
-
-                    // do not allow to go back in time
-                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
-                        break;
-                    }
-
-                    seek_delta = seek_delta_new;
-                    result_len = i + 1;
-                    has_ts = true;
-                }
-
-                // add it to the context
-                prompt.push_back(token.id);
-                tokens_cur.push_back(token);
-
-                //{
-                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
-                //}
-
-                // end of segment
-                if (token.id == whisper_token_eot(ctx) ||                // end of text token
-                    (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
-                    (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
-                    ) {
-                    if (result_len == 0) {
-                        if (seek + seek_delta + 100 >= seek_end) {
-                            result_len = i + 1;
-                        } else {
-                            failed = true;
-                            break;
-                        }
-                    }
-
-                    if (params.single_segment) {
-                        result_len = i + 1;
-                        seek_delta = 100*WHISPER_CHUNK_SIZE;
-                    }
-
-                    break;
-                }
-
-                // TESTS: if no tensors are loaded, it means we are running tests
-                if (ctx->model.n_loaded == 0) {
-                    seek_delta = 100*WHISPER_CHUNK_SIZE;
-                    break;
-                }
-            }
-
-            // sometimes, the decoding can get stuck in a repetition loop
-            // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
-            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
-                failed = true;
-                break;
-            }
-        }
-
-        if (failed) {
-            // when we fail to sample timestamp token, retry by clearing the past prompt
-            // if it fails again, then we advance the window by 1 second
-            if (!prompt_past.empty()) {
-                prompt_past.clear();
-            } else {
-                fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
-                seek += 100;
-            }
-            continue;
-        }
-
-        // shrink down to result_len
-        tokens_cur.resize(result_len);
-
-        for (const auto & r : tokens_cur) {
-            prompt_past.push_back(r.id);
-        }
+        //prompt.clear();
+
+        //// if we have already generated some text, use it as a prompt to condition the next generation
+        //if (!prompt_past.empty()) {
+        //    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+
+        //    prompt = { whisper_token_prev(ctx) };
+        //    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+
+        //    prompt_past.clear();
+        //    prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+        //}
+
+        //prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+        //// print the prompt
+        ////printf("\n\n");
+        ////for (int i = 0; i < prompt.size(); i++) {
+        ////    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+        ////}
+        ////printf("\n\n");
+
+        //int n_past = 0;
+        //int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
+        //// the accumulated transcription in the current iteration
+        //int result_len = 0;
+
+        //tokens_cur.clear();
+
+        //bool failed = false; // has the current segment failed to decode?
+        //bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+
+        //double sum_logprobs = 0.0;
+
+        //for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+        //    if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
+        //        fprintf(stderr, "%s: failed to decode\n", __func__);
+        //        return -7;
+        //    }
+
+        //    n_past += prompt.size();
+        //    prompt.clear();
+
+        //    // sample the next token based on the selected decoding strategy + parameters
+        //    // also, update the sliding window position based on the sampled timestamp tokens
+        //    {
+        //        const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
+
+        //        // timestamp token - update sliding window
+        //        if (token.id > whisper_token_beg(ctx)) {
+        //            const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+        //            // do not allow to go back in time
+        //            if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+        //                break;
+        //            }
+
+        //            seek_delta = seek_delta_new;
+        //            result_len = i + 1;
+        //            has_ts = true;
+        //        }
+
+        //        // add it to the context
+        //        prompt.push_back(token.id);
+        //        tokens_cur.push_back(token);
+
+        //        //{
+        //        //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+        //        //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+        //        //}
+
+        //        // end of segment
+        //        if (token.id == whisper_token_eot(ctx) ||                // end of text token
+        //            (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+        //            (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
+        //            ) {
+        //            if (result_len == 0) {
+        //                if (seek + seek_delta + 100 >= seek_end) {
+        //                    result_len = i + 1;
+        //                } else {
+        //                    failed = true;
+        //                    break;
+        //                }
+        //            }
+
+        //            if (params.single_segment) {
+        //                result_len = i + 1;
+        //                seek_delta = 100*WHISPER_CHUNK_SIZE;
+        //            }
+
+        //            break;
+        //        }
+
+        //        // TESTS: if no tensors are loaded, it means we are running tests
+        //        if (ctx->model.n_loaded == 0) {
+        //            seek_delta = 100*WHISPER_CHUNK_SIZE;
+        //            break;
+        //        }
+        //    }
+
+        //    // sometimes, the decoding can get stuck in a repetition loop
+        //    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
+        //    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+        //        failed = true;
+        //        break;
+        //    }
+        //}
+
+        //if (failed) {
+        //    // when we fail to sample timestamp token, retry by clearing the past prompt
+        //    // if it fails again, then we advance the window by 1 second
+        //    if (!prompt_past.empty()) {
+        //        prompt_past.clear();
+        //    } else {
+        //        fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
+        //        seek += 100;
+        //    }
+        //    continue;
+        //}
+
+        //// shrink down to result_len
+        //tokens_cur.resize(result_len);
+
+        //for (const auto & r : tokens_cur) {
+        //    prompt_past.push_back(r.id);
+        //}
+
+        for (int it = 0; it < (int) temperatures.size(); ++it) {
+            const float t_cur = temperatures[it];
+
+            int n_decoders_cur = 1;
+
+            switch (params.strategy) {
+                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        }
+                    } break;
+                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        } else {
+                            n_decoders_cur = params.beam_search.beam_size;
+                        }
+                    } break;
+            };
+
+            fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+
+            if (t_cur > 0.5) {
+                prompt_past.clear();
+
+                fprintf(stderr, "%s: clearing prompt_past\n", __func__);
+            }
+
+            // TAGS: WHISPER_DECODER_INIT
+            for (int i = 0; i < n_decoders_cur; ++i) {
+                auto & decoder = ctx->decoders[i];
+
+                decoder.sequence.tokens.clear();
+                decoder.sequence.result_len   = 0;
+                decoder.sequence.sum_logprobs = 0.0;
+
+                decoder.n_past     = 0;
+                decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
+
+                decoder.failed    = false;
+                decoder.completed = false;
+                decoder.has_ts    = false;
+            }
+
+            // init prompt and kv cache for the current iteration
+            // run whisper_decode() only for decoder 0 and copy the results for the other decoders
+            {
+                prompt.clear();
+
+                // if we have already generated some text, use it as a prompt to condition the next generation
+                if (!prompt_past.empty()) {
+                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+
+                    prompt = { whisper_token_prev(ctx) };
+                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+                }
+
+                // init new transcription with sot, language (opt) and task tokens
+                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+                // print the prompt
+                //printf("\n\n");
+                //for (int i = 0; i < prompt.size(); i++) {
+                //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+                //}
+                //printf("\n\n");
+
+                if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) {
+                    fprintf(stderr, "%s: failed to decode\n", __func__);
+                    return -7;
+                }
+
+                whisper_process_logits(*ctx, ctx->decoders[0], params);
+
+                for (int i = 1; i < n_decoders_cur; ++i) {
+                    auto & decoder = ctx->decoders[i];
+
+                    memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
+
+                    decoder.n_past += prompt.size();
+
+                    memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                    memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
+                }
+            }
+
+            for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+                bool completed = true;
+
+                switch (params.strategy) {
+                    case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                        {
+                            for (int i = 0; i < n_decoders_cur; ++i) {
+                                auto & decoder = ctx->decoders[i];
+
+                                if (decoder.completed || decoder.failed) {
+                                    continue;
+                                }
+
+                                if (t_cur < 1e-6f) {
+                                    // select top token
+                                } else {
+                                }
+                            }
+                        } break;
+                    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                        {
+                            // TODO: ..
+                        } break;
+                };
+
+                if (completed) {
+                    break;
+                }
+
+                for (int i = 0; i < n_decoders_cur; ++i) {
+                    auto & decoder = ctx->decoders[i];
+
+                    if (decoder.failed || decoder.completed) {
+                        continue;
+                    }
+
+                    decoder.tokens.resize(1);
+                    decoder.tokens[0] = decoder.sequence.tokens.back().id;
+
+                    if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) {
+                        fprintf(stderr, "%s: failed to decode\n", __func__);
+                        return -7;
+                    }
+
+                    whisper_process_logits(*ctx, decoder, params);
+
+                    ++decoder.n_past;
+                }
+            }
+
+            // TODO: rank the resulting sequences and select the best one
+            {
+            }
+
+            bool success = true;
+
+            // TODO: implement logprob threshold + compression threshold
+            {
+            }
+
+            if (success) {
+                break;
+            }
+
+            fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
+        }
+
+        // TODO
+        const int seek_delta = 0;
+
+        whisper_sequence seq_best;
+
+        const auto & tokens_cur = seq_best.tokens;
+
+        // TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens
+        //prompt_past.clear();
+        //prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
 
         // store the text from this iteration
         if (!tokens_cur.empty()) {
             int i0 = 0;
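The greedy branch in the loop above is still a stub (`// select top token` and an empty `else`). The intended behaviour is argmax sampling at t_cur == 0 and softmax sampling at t_cur > 0; a rough illustration of both sub-cases (assumes <algorithm>, <cmath>, <random>; not this commit's final code):

    whisper_token id = 0;
    if (t_cur < 1e-6f) {
        // zero temperature: take the single most probable token
        id = (whisper_token) std::distance(decoder.logits.begin(),
            std::max_element(decoder.logits.begin(), decoder.logits.end()));
    } else {
        // sample from softmax(logits/t_cur); subtract the max for numerical stability
        const float mx = *std::max_element(decoder.logits.begin(), decoder.logits.end());
        std::vector<double> w(decoder.logits.size());
        for (size_t j = 0; j < w.size(); ++j) {
            w[j] = exp((decoder.logits[j] - mx)/t_cur);
        }
        std::discrete_distribution<int> dist(w.begin(), w.end());
        std::mt19937 rng(0); // illustration; real code would reuse one generator
        id = (whisper_token) dist(rng);
    }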
@@ -3450,10 +3603,10 @@ int whisper_full(
                 if (params.token_timestamps) {
                     whisper_exp_compute_token_level_timestamps(
-                            ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                            *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                     if (params.max_len > 0) {
-                        n_new = whisper_wrap_segment(ctx, params.max_len);
+                        n_new = whisper_wrap_segment(*ctx, params.max_len);
                     }
                 }
 
                 if (params.new_segment_callback) {
@@ -3494,10 +3647,10 @@ int whisper_full(
                 if (params.token_timestamps) {
                     whisper_exp_compute_token_level_timestamps(
-                            ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                            *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                     if (params.max_len > 0) {
-                        n_new = whisper_wrap_segment(ctx, params.max_len);
+                        n_new = whisper_wrap_segment(*ctx, params.max_len);
                     }
                 }
 
                 if (params.new_segment_callback) {
@@ -3535,9 +3688,6 @@ int whisper_full_parallel(
         ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
         ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
 
-        ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab);
-        ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab);
-
         ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
 
         if (!kv_cache_reinit(ctx_p.kv_cross)) {
@@ -3545,6 +3695,7 @@ int whisper_full_parallel(
             return false;
         }
 
+        // TAGS: WHISPER_DECODER_INIT
         for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
             if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
                 fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
@@ -3552,7 +3703,9 @@ int whisper_full_parallel(
             }
 
             ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-            ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx);
+
+            ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
         }
     }
@@ -3747,14 +3900,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 }
 
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
+        struct whisper_context & ctx,
         int   i_segment,
         float thold_pt,
         float thold_ptsum) {
-    auto & segment = ctx->result_all[i_segment];
+    auto & segment = ctx.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx->energy.size();
+    const int n_samples = ctx.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -3777,28 +3930,28 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx->t_beg;
-    auto & t_last   = ctx->t_last;
-    auto & tid_last = ctx->tid_last;
+    auto & t_beg    = ctx.t_beg;
+    auto & t_last   = ctx.t_last;
+    auto & tid_last = ctx.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
 
         if (j == 0) {
-            if (token.id == whisper_token_beg(ctx)) {
+            if (token.id == whisper_token_beg(&ctx)) {
                 tokens[j    ].t0 = t0;
                 tokens[j    ].t1 = t0;
                 tokens[j + 1].t0 = t0;
 
                 t_beg    = t0;
                 t_last   = t0;
-                tid_last = whisper_token_beg(ctx);
+                tid_last = whisper_token_beg(&ctx);
             } else {
                 tokens[j    ].t0 = t_last;
             }
         }
 
-        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
 
         tokens[j].id    = token.id;
         tokens[j].tid   = token.tid;
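As a reading aid for the `tt` expression above: timestamp tokens are spaced 20 ms apart, while whisper.cpp counts time in 10 ms units, hence the factor of 2. The same conversion as a hypothetical helper:

    // timestamp token id -> time in 10 ms units, relative to the segment's t_beg
    static int64_t whisper_tid_to_t(whisper_context & ctx, whisper_token tid, int64_t t_beg) {
        return t_beg + 2*(tid - whisper_token_beg(&ctx));
    }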
@@ -3806,7 +3959,7 @@ static void whisper_exp_compute_token_level_timestamps(
         tokens[j].pt    = token.pt;
         tokens[j].ptsum = token.ptsum;
 
-        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+        tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
 
         if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
             if (j > 0) {
@@ -3885,7 +4038,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int hw = WHISPER_SAMPLE_RATE/8;
 
     for (int j = 0; j < n; j++) {
-        if (tokens[j].id >= whisper_token_eot(ctx)) {
+        if (tokens[j].id >= whisper_token_eot(&ctx)) {
             continue;
         }
@@ -3900,15 +4053,15 @@ static void whisper_exp_compute_token_level_timestamps(
 
         float sum = 0.0f;
         for (int k = ss0; k < ss1; k++) {
-            sum += ctx->energy[k];
+            sum += ctx.energy[k];
         }
 
         const float thold = 0.5*sum/ns;
 
         {
             int k = s0;
-            if (ctx->energy[k] > thold && j > 0) {
-                while (k > 0 && ctx->energy[k] > thold) {
+            if (ctx.energy[k] > thold && j > 0) {
+                while (k > 0 && ctx.energy[k] > thold) {
                     k--;
                 }
                 tokens[j].t0 = sample_to_timestamp(k);
@@ -3918,7 +4071,7 @@ static void whisper_exp_compute_token_level_timestamps(
                     s0 = k;
                 }
             } else {
-                while (ctx->energy[k] < thold && k < s1) {
+                while (ctx.energy[k] < thold && k < s1) {
                     k++;
                 }
                 s0 = k;
@@ -3928,8 +4081,8 @@ static void whisper_exp_compute_token_level_timestamps(
 
         {
             int k = s1;
-            if (ctx->energy[k] > thold) {
-                while (k < n_samples - 1 && ctx->energy[k] > thold) {
+            if (ctx.energy[k] > thold) {
+                while (k < n_samples - 1 && ctx.energy[k] > thold) {
                     k++;
                 }
                 tokens[j].t1 = sample_to_timestamp(k);
@@ -3939,7 +4092,7 @@ static void whisper_exp_compute_token_level_timestamps(
                     s1 = k;
                 }
             } else {
-                while (ctx->energy[k] < thold && k > s0) {
+                while (ctx.energy[k] < thold && k > s0) {
                     k--;
                 }
                 s1 = k;
