whisper : wip decoding parameters + strategies (part 3)

pull/291/head
Georgi Gerganov 3 years ago
parent 3d723d0b82
commit 116dd67a15

@@ -422,7 +422,9 @@ struct whisper_sequence {
// the accumulated transcription in the current iteration (used to truncate the tokens array)
int result_len;
double sum_logprobs;
double sum_logprobs; // the sum of the log probabilities of the tokens
double avg_logprobs; // the average log probability of the tokens
double score; // likelihood rank score
};
// TAGS: WHISPER_DECODER_INIT
@@ -438,11 +440,12 @@ struct whisper_decoder {
bool completed; // has the decoder completed the current segment?
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
std::vector<whisper_token> tokens;
// new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
// new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
std::vector<float> probs;
std::vector<float> logits;
std::vector<float> logprobs;
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
};
struct whisper_context {
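The decoder now keeps probs, logits and logprobs for the last decoded position side by side. As a point of reference, a minimal sketch of how the three arrays typically relate (softmax over the logits, then an element-wise log); this is an illustration only, not the code of whisper_process_logits, and the helper name is made up:

// illustrative only: probs = softmax(logits), logprobs = log(probs)
#include <algorithm>
#include <cmath>
#include <vector>

static void compute_probs_and_logprobs(
        const std::vector<float> & logits,
              std::vector<float> & probs,
              std::vector<float> & logprobs) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    probs.resize(logits.size());
    logprobs.resize(logits.size());

    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = expf(logits[i] - max_logit); // subtract the max for numerical stability
        sum += probs[i];
    }
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i]    = float(probs[i]/sum);
        logprobs[i] = logf(probs[i]);
    }
}

With this relation, a token whose logit has been masked to -INFINITY ends up with prob 0 and logprob -INFINITY.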
@@ -774,6 +777,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// TAGS: WHISPER_DECODER_INIT
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
wctx.decoders[0].probs.reserve (vocab.n_vocab);
wctx.decoders[0].logits.reserve (vocab.n_vocab);
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
}
@@ -1959,99 +1963,6 @@ static bool whisper_decode(
return true;
}
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
whisper_context & ctx,
const float * probs,
bool force_timestamp,
bool is_initial) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
const auto & vocab = ctx.vocab;
const int n_logits = vocab.n_vocab;
auto & probs_id = ctx.probs_id;
probs_id.clear();
for (int i = 0; i < n_logits; i++) {
probs_id.emplace_back(probs[i], i);
}
{
double sum_ts = 0.0;
double max_ts = -1.0;
double max_tx = -1.0;
for (int i = 0; i < vocab.token_beg; i++) {
max_tx = std::max(max_tx, probs_id[i].first);
}
const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
// the initial timestamp cannot be larger than 100
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial) {
for (int i = i0; i < n_logits; ++ i) {
probs_id[i].first = -INFINITY;
}
}
for (int i = vocab.token_beg; i < i1; i++) {
sum_ts += probs_id[i].first;
if (probs_id[i].first > max_ts) {
max_ts = probs_id[i].first;
result.tid = probs_id[i].second;
}
}
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
if (sum_ts > max_tx || force_timestamp) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY;
}
}
result.pt = max_ts/(sum_ts + 1e-10);
result.ptsum = sum_ts;
}
// find the top K tokens
const int top_k = 4;
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
[](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
return a.first > b.first;
});
probs_id.resize(top_k);
//printf("\n");
//for (int i = 0; i < (int) probs_id.size(); i++) {
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
int res = 0;
while ((probs_id[res].second == vocab.token_sot ||
probs_id[res].second == vocab.token_solm ||
probs_id[res].second == vocab.token_not) &&
res < (int) probs_id.size() - 1) {
res++;
}
result.id = probs_id[res].second;
result.p = probs_id[res].first;
return result;
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2793,8 +2704,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_increment =*/ 0.2f,
/*.compression_ratio_threshold =*/ 2.4f,
@@ -2809,7 +2721,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_size =*/ -1,
/*.patience =*/ -1.0f,
/*.length_penalty =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
@@ -2853,8 +2764,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_increment =*/ 0.2f,
/*.compression_ratio_threshold =*/ 2.4f,
@@ -2869,7 +2781,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_size =*/ 5,
/*.patience =*/ -1.0f,
/*.length_penalty =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
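Both default parameter sets start at temperature = 0.0f with temperature_increment = 0.2f, so the temperatures vector iterated further down in whisper_full would correspond to the OpenAI-style schedule 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 (the construction of that vector is not shown in this diff). A hedged sketch of building such a schedule; the helper name is hypothetical and not part of this commit:

#include <vector>

// builds the fallback schedule implied by params.temperature and params.temperature_increment
static std::vector<float> make_temperature_schedule(float temperature, float increment) {
    std::vector<float> temperatures;
    if (increment > 0.0f) {
        for (float t = temperature; t < 1.0f + 1e-6f; t += increment) {
            temperatures.push_back(t); // 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 with the defaults
        }
    } else {
        temperatures.push_back(temperature);
    }
    return temperatures;
}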
@@ -2963,9 +2874,13 @@ static void whisper_process_logits(
// extract the logits for the last token
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
auto & probs = decoder.probs;
auto & logits = decoder.logits;
auto & logprobs = decoder.logprobs;
{
probs.resize(n_logits);
memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float));
logits.resize(n_logits);
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
@@ -2995,6 +2910,8 @@ static void whisper_process_logits(
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
//fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
if (last_was_timestamp) {
if (penultimate_was_timestamp) {
for (int i = vocab.token_beg; i < n_logits; ++ i) {
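The hunk above is cut off, so for context: this block applies the timestamp pairing rule from the OpenAI reference decoder (decoding.py, ApplyTimestampRules), where timestamp tokens are expected to come in pairs. A sketch of the full rule, assuming masking a token means setting its logit to -INFINITY; the wrapper function and its name are hypothetical:

#include <cmath>

// hypothetical wrapper around the pairing rule used inside whisper_process_logits
static void apply_timestamp_pairing_rule(
        const whisper_vocab & vocab,
        bool    last_was_timestamp,
        bool    penultimate_was_timestamp,
        int     n_logits,
        float * logits) {
    if (!last_was_timestamp) {
        return;
    }
    if (penultimate_was_timestamp) {
        // two timestamps in a row - the pair is closed, suppress further timestamp tokens
        for (int i = vocab.token_beg; i < n_logits; ++i) {
            logits[i] = -INFINITY;
        }
    } else {
        // a single timestamp was just emitted - force it to be paired by suppressing text tokens
        for (int i = 0; i < vocab.token_eot; ++i) {
            logits[i] = -INFINITY;
        }
    }
}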
@@ -3099,6 +3016,83 @@ static void whisper_process_logits(
//return result;
}
// select the most probable token
static whisper_token_data whisper_sample_best(
whisper_context & ctx,
whisper_decoder & decoder) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
const auto & vocab = ctx.vocab;
const auto & probs = decoder.probs;
const auto & logprobs = decoder.logprobs;
const int n_logits = vocab.n_vocab;
{
double sum_ts = 0.0;
double max_ts = 0.0;
for (int i = vocab.token_beg; i < n_logits; i++) {
if (probs[i] == -INFINITY) {
continue;
}
sum_ts += probs[i];
if (max_ts < probs[i]) {
max_ts = probs[i];
result.tid = i;
}
}
result.pt = max_ts/(sum_ts + 1e-10);
result.ptsum = sum_ts;
}
for (int i = 0; i < n_logits; ++i) {
// never sample these:
if (i == vocab.token_sot ||
i == vocab.token_solm ||
i == vocab.token_not) {
continue;
}
if (result.p < probs[i]) {
result.id = i;
result.p = probs[i];
result.plog = logprobs[i];
}
}
return result;
}
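To make pt and ptsum concrete: if the timestamp tokens together carry 0.60 of the probability mass after whisper_process_logits and the single most probable timestamp token has probability 0.30, then result.ptsum = 0.60 and result.pt = 0.30/0.60 ≈ 0.5. Both values are later consumed by the token-level timestamp heuristics (params.thold_pt, params.thold_ptsum) further down in whisper_full.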
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
static void whisper_sequence_score(
const struct whisper_full_params & params,
whisper_sequence & sequence) {
WHISPER_ASSERT(sequence.result_len > 0);
double result = 0.0f;
for (int i = 0; i < sequence.result_len; ++i) {
result += sequence.tokens[i].plog;
}
sequence.sum_logprobs = result;
sequence.avg_logprobs = result/sequence.result_len;
double penalty = sequence.result_len;
if (params.length_penalty > 0.0f) {
penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
}
sequence.score = result/penalty;
}
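A worked example of the scoring: with result_len = 10, sum_logprobs = -5.0 and the default length_penalty = -1.0f, the penalty stays equal to result_len, so score = -5.0/10 = -0.50, i.e. the average log probability. With length_penalty = 0.6 instead, penalty = ((5 + 10)/6)^0.6 ≈ 1.73 and score ≈ -2.89, which is the Google NMT style length penalty used in the referenced OpenAI code.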
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@@ -3191,6 +3185,7 @@ int whisper_full(
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
ctx->decoders[i].probs.reserve (ctx->vocab.n_vocab);
ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab);
ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
}
@@ -3261,7 +3256,7 @@ int whisper_full(
}
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
if (!whisper_encode(*ctx, seek, params.n_threads)) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return -6;
}
@@ -3272,132 +3267,7 @@ int whisper_full(
prompt_past.clear();
}
//prompt.clear();
//// if we have already generated some text, use it as a prompt to condition the next generation
//if (!prompt_past.empty()) {
// int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
// prompt = { whisper_token_prev(ctx) };
// prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
// prompt_past.clear();
// prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
//}
//prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
//// print the prompt
////printf("\n\n");
////for (int i = 0; i < prompt.size(); i++) {
//// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
////}
////printf("\n\n");
//int n_past = 0;
//int seek_delta = 100*WHISPER_CHUNK_SIZE;
//// the accumulated transcription in the current iteration
//int result_len = 0;
//tokens_cur.clear();
//bool failed = false; // has the current segment failed to decode?
//bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
//double sum_logprobs = 0.0;
//for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
// if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
// fprintf(stderr, "%s: failed to decode\n", __func__);
// return -7;
// }
// n_past += prompt.size();
// prompt.clear();
// // sample the next token based on the selected decoding strategy + parameters
// // also, update the sliding window position based on the sampled timestamp tokens
// {
// const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
// // timestamp token - update sliding window
// if (token.id > whisper_token_beg(ctx)) {
// const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// // do not allow to go back in time
// if (has_ts && seek_delta > seek_delta_new && result_len < i) {
// break;
// }
// seek_delta = seek_delta_new;
// result_len = i + 1;
// has_ts = true;
// }
// // add it to the context
// prompt.push_back(token.id);
// tokens_cur.push_back(token);
// //{
// // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
// //}
// // end of segment
// if (token.id == whisper_token_eot(ctx) || // end of text token
// (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
// (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
// ) {
// if (result_len == 0) {
// if (seek + seek_delta + 100 >= seek_end) {
// result_len = i + 1;
// } else {
// failed = true;
// break;
// }
// }
// if (params.single_segment) {
// result_len = i + 1;
// seek_delta = 100*WHISPER_CHUNK_SIZE;
// }
// break;
// }
// // TESTS: if no tensors are loaded, it means we are running tests
// if (ctx->model.n_loaded == 0) {
// seek_delta = 100*WHISPER_CHUNK_SIZE;
// break;
// }
// }
// // sometimes, the decoding can get stuck in a repetition loop
// // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
// if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
// failed = true;
// break;
// }
//}
//if (failed) {
// // when we fail to sample timestamp token, retry by clearing the past prompt
// // if it fails again, then we advance the window by 1 second
// if (!prompt_past.empty()) {
// prompt_past.clear();
// } else {
// fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
// seek += 100;
// }
// continue;
//}
//// shrink down to result_len
//tokens_cur.resize(result_len);
//for (const auto & r : tokens_cur) {
// prompt_past.push_back(r.id);
//}
int best_decoder_id = 0;
for (int it = 0; it < (int) temperatures.size(); ++it) {
const float t_cur = temperatures[it];
@@ -3429,12 +3299,14 @@ int whisper_full(
}
// TAGS: WHISPER_DECODER_INIT
for (int i = 0; i < n_decoders_cur; ++i) {
auto & decoder = ctx->decoders[i];
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs = 0.0;
decoder.sequence.avg_logprobs = 0.0;
decoder.sequence.score = 0.0;
decoder.n_past = 0;
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
@@ -3467,65 +3339,148 @@ int whisper_full(
//}
//printf("\n\n");
if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) {
if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
whisper_process_logits(*ctx, ctx->decoders[0], params);
for (int i = 1; i < n_decoders_cur; ++i) {
auto & decoder = ctx->decoders[i];
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
decoder.n_past += prompt.size();
memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}
}
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
bool completed = true;
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
for (int i = 0; i < n_decoders_cur; ++i) {
auto & decoder = ctx->decoders[i];
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
if (decoder.completed || decoder.failed) {
continue;
}
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
if (t_cur < 1e-6f) {
// select top token
decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder));
} else {
}
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
// TODO: ..
} break;
};
auto & has_ts = decoder.has_ts;
auto & failed = decoder.failed;
auto & completed = decoder.completed;
auto & seek_delta = decoder.seek_delta;
auto & result_len = decoder.sequence.result_len;
{
const auto & token = decoder.sequence.tokens.back();
// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
failed = true; // TODO: maybe this is not a failure ?
break;
}
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
seek_delta = seek_delta_new;
result_len = i + 1;
has_ts = true;
}
{
// TODO: ..
} break;
};
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
}
if (completed) {
break;
// end of segment
if (token.id == whisper_token_eot(ctx) || // end of text token
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
failed = true;
break;
}
}
if (params.single_segment) {
result_len = i + 1;
seek_delta = 100*WHISPER_CHUNK_SIZE;
}
completed = true;
break;
}
// TESTS: if no tensors are loaded, it means we are running tests
if (ctx->model.n_loaded == 0) {
seek_delta = 100*WHISPER_CHUNK_SIZE;
completed = true;
break;
}
}
// sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
failed = true;
break;
}
}
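For scale: timestamps here are expressed in 10 ms ticks and WHISPER_CHUNK_SIZE is 30, so 100*WHISPER_CHUNK_SIZE/2 = 1500 corresponds to 15 seconds. A decoder that exhausts its token budget (n_max) without producing a result or advancing the window by at least that much is flagged as failed and is excluded from the sequence ranking below.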
// check if all decoders have finished (i.e. completed or failed)
{
bool completed_all = true;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
completed_all = false;
}
if (completed_all) {
break;
}
}
for (int i = 0; i < n_decoders_cur; ++i) {
auto & decoder = ctx->decoders[i];
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
decoder.tokens.resize(1);
decoder.tokens[0] = decoder.sequence.tokens.back().id;
decoder.tokens_tmp.resize(1);
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) {
if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
@@ -3536,14 +3491,36 @@ int whisper_full(
}
}
// TODO: rank the resulting sequences and select the best one
// rank the resulting sequences and select the best one
{
double best_score = -1e9;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.failed) {
continue;
}
whisper_sequence_score(params, ctx->decoders[j].sequence);
if (best_score < decoder.sequence.score) {
best_score = decoder.sequence.score;
best_decoder_id = j;
}
}
}
bool success = true;
// TODO: implement logprob threshold + compression threshold
// implement logprob threshold
// TODO: implement compression threshold
{
auto & decoder = ctx->decoders[best_decoder_id];
if (decoder.sequence.avg_logprobs < params.logprob_threshold) {
success = false;
}
}
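For example, assuming the default logprob_threshold of -1.0 used by the OpenAI reference (the actual default is not visible in this diff), a best sequence with avg_logprobs = -1.2 marks this pass as unsuccessful and the next, higher temperature is tried.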
if (success) {
@@ -3553,113 +3530,121 @@ int whisper_full(
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
}
// TODO
const int seek_delta = 0;
whisper_sequence seq_best;
{
const auto & best_decoder = ctx->decoders[best_decoder_id];
const auto & tokens_cur = seq_best.tokens;
const auto seek_delta = best_decoder.seek_delta;
const auto result_len = best_decoder.sequence.result_len;
// TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens
//prompt_past.clear();
//prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
const auto & tokens_cur = best_decoder.sequence.tokens;
// store the text from this iteration
if (!tokens_cur.empty()) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
// update prompt_past
prompt_past.clear();
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
std::string text;
for (int i = 0; i < result_len; ++i) {
prompt_past.push_back(tokens_cur[i].id);
}
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
// store the text from this iteration
if (!tokens_cur.empty()) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
std::string text;
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
}
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
i++;
}
i--;
t0 = t1;
i0 = i + 1;
}
text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
i++;
}
i--;
t0 = t1;
i0 = i + 1;
}
}
if (!text.empty()) {
const auto t1 = seek + seek_delta;
if (!text.empty()) {
const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
}
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
seek += seek_delta;
// update audio window
seek += seek_delta;
}
}
return 0;
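For orientation, a minimal usage sketch of the updated whisper_full path. The model path and PCM loading are placeholders, and the field names follow this commit's whisper_full_params layout; treat it as a sketch rather than a drop-in example:

#include <cstdio>
#include "whisper.h"

int transcribe(const float * pcmf32, int n_samples) {
    struct whisper_context * wctx = whisper_init("models/ggml-base.en.bin"); // placeholder model path
    if (wctx == nullptr) {
        return 1;
    }

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.temperature           = 0.0f;  // start greedy ...
    wparams.temperature_increment = 0.2f;  // ... and fall back in 0.2 steps when a pass fails
    wparams.logprob_threshold     = -1.0f; // reject passes whose average logprob is below this

    if (whisper_full(wctx, wparams, pcmf32, n_samples) != 0) {
        whisper_free(wctx);
        return 1;
    }

    const int n_segments = whisper_full_n_segments(wctx);
    for (int i = 0; i < n_segments; ++i) {
        printf("%s\n", whisper_full_get_segment_text(wctx, i));
    }

    whisper_free(wctx);
    return 0;
}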
@@ -3704,6 +3689,7 @@ int whisper_full_parallel(
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
}

@@ -74,6 +74,7 @@ extern "C" {
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
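plog is the log probability of the sampled token (filled from decoder.logprobs in whisper_sample_best); whisper_sequence_score then sums the per-token plog values into sequence.sum_logprobs when ranking the decoders.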
@@ -280,6 +281,7 @@ extern "C" {
float temperature;
float max_initial_timestamp;
float length_penalty;
// fallback parameters
float temperature_increment;
@@ -295,7 +297,6 @@ extern "C" {
int beam_size;
float patience;
float length_penalty;
} beam_search;
whisper_new_segment_callback new_segment_callback;
