|
|
|
@ -433,10 +433,11 @@ struct whisper_sequence {
|
|
|
|
|
// the accumulated transcription in the current interation (used to truncate the tokens array)
|
|
|
|
|
int result_len;
|
|
|
|
|
|
|
|
|
|
double sum_logprobs; // the sum of the log probabilities of the tokens
|
|
|
|
|
double avg_logprobs; // the average log probability of the tokens
|
|
|
|
|
double entropy; // the entropy of the tokens
|
|
|
|
|
double score; // likelihood rank score
|
|
|
|
|
double sum_logprobs_all; // the sum of the log probabilities of the tokens
|
|
|
|
|
double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
|
|
|
|
|
double avg_logprobs; // the average log probability of the tokens
|
|
|
|
|
double entropy; // the entropy of the tokens
|
|
|
|
|
double score; // likelihood rank score
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
@ -1650,7 +1651,7 @@ static bool whisper_decode(
|
|
|
|
|
const int N = n_tokens;
|
|
|
|
|
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
|
|
|
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
|
|
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params;
|
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
|
@ -3100,6 +3101,74 @@ static whisper_token_data whisper_sample_token(
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
|
|
whisper_context & ctx,
|
|
|
|
|
const whisper_decoder & decoder,
|
|
|
|
|
int k) {
|
|
|
|
|
const auto & vocab = ctx.vocab;
|
|
|
|
|
|
|
|
|
|
const auto & probs = decoder.probs;
|
|
|
|
|
const auto & logits = decoder.logits;
|
|
|
|
|
const auto & logprobs = decoder.logprobs;
|
|
|
|
|
|
|
|
|
|
const int n_logits = vocab.n_vocab;
|
|
|
|
|
|
|
|
|
|
auto & logits_id = ctx.logits_id;
|
|
|
|
|
|
|
|
|
|
logits_id.clear();
|
|
|
|
|
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
|
logits_id.push_back({ logits[i], i });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::partial_sort(
|
|
|
|
|
logits_id.begin(),
|
|
|
|
|
logits_id.begin() + k, logits_id.end(),
|
|
|
|
|
[](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token_data> result(k);
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
double sum_ts = 0.0;
|
|
|
|
|
double max_ts = 0.0;
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; i++) {
|
|
|
|
|
if (probs[i] == -INFINITY) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sum_ts += probs[i];
|
|
|
|
|
if (max_ts < probs[i]) {
|
|
|
|
|
max_ts = probs[i];
|
|
|
|
|
result[0].tid = i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result[0].pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
result[0].ptsum = sum_ts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < k; ++i) {
|
|
|
|
|
result[i].id = logits_id[i].second;
|
|
|
|
|
result[i].p = probs[result[i].id];
|
|
|
|
|
result[i].plog = logprobs[result[i].id];
|
|
|
|
|
result[i].tid = result[0].tid;
|
|
|
|
|
result[i].pt = result[0].pt;
|
|
|
|
|
result[i].ptsum = result[0].ptsum;
|
|
|
|
|
result[i].t0 = -1;
|
|
|
|
|
result[i].t1 = -1;
|
|
|
|
|
result[i].vlen = 0.0f;
|
|
|
|
|
|
|
|
|
|
if (result[i].id >= vocab.token_beg) {
|
|
|
|
|
result[i].tid = result[i].id;
|
|
|
|
|
result[i].pt = result[i].p;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
|
|
|
|
|
static void whisper_sequence_score(
|
|
|
|
|
const struct whisper_full_params & params,
|
|
|
|
@ -3120,7 +3189,7 @@ static void whisper_sequence_score(
|
|
|
|
|
double penalty = sequence.result_len;
|
|
|
|
|
|
|
|
|
|
if (params.length_penalty > 0.0f) {
|
|
|
|
|
penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
|
|
|
|
|
penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sequence.score = result/penalty;
|
|
|
|
@ -3141,7 +3210,8 @@ static void whisper_sequence_score(
|
|
|
|
|
for (const auto & kv : token_counts) {
|
|
|
|
|
const auto p = kv.second/(double)cnt;
|
|
|
|
|
entropy -= p*log(p);
|
|
|
|
|
//printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
|
|
|
|
|
|
|
|
|
|
//WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sequence.entropy = entropy;
|
|
|
|
@ -3293,6 +3363,25 @@ int whisper_full(
|
|
|
|
|
std::vector<whisper_token> prompt;
|
|
|
|
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
|
|
|
|
|
|
|
|
|
// beam-search helpers
|
|
|
|
|
struct kv_buf {
|
|
|
|
|
std::vector<uint8_t> k;
|
|
|
|
|
std::vector<uint8_t> v;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::vector<kv_buf> kv_bufs;
|
|
|
|
|
|
|
|
|
|
struct beam_candidate {
|
|
|
|
|
int decoder_idx;
|
|
|
|
|
int seek_delta;
|
|
|
|
|
|
|
|
|
|
bool has_ts;
|
|
|
|
|
|
|
|
|
|
whisper_sequence sequence;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::vector<beam_candidate> beam_candidates;
|
|
|
|
|
|
|
|
|
|
// main loop
|
|
|
|
|
while (true) {
|
|
|
|
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
|
|
@ -3360,11 +3449,12 @@ int whisper_full(
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
decoder.sequence.tokens.clear();
|
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
|
decoder.sequence.sum_logprobs = -INFINITY;
|
|
|
|
|
decoder.sequence.avg_logprobs = -INFINITY;
|
|
|
|
|
decoder.sequence.entropy = 0.0f;
|
|
|
|
|
decoder.sequence.score = -INFINITY;
|
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
|
decoder.sequence.sum_logprobs_all = 0.0;
|
|
|
|
|
decoder.sequence.sum_logprobs = -INFINITY;
|
|
|
|
|
decoder.sequence.avg_logprobs = -INFINITY;
|
|
|
|
|
decoder.sequence.entropy = 0.0;
|
|
|
|
|
decoder.sequence.score = -INFINITY;
|
|
|
|
|
|
|
|
|
|
decoder.n_past = 0;
|
|
|
|
|
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
@ -3412,7 +3502,8 @@ int whisper_full(
|
|
|
|
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
|
|
|
|
|
memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
|
|
|
|
|
memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
|
|
|
|
|
|
|
|
|
|
decoder.n_past += prompt.size();
|
|
|
|
|
|
|
|
|
@ -3428,6 +3519,25 @@ int whisper_full(
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
|
|
|
kv_bufs.resize(n_decoders_cur);
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
|
|
|
|
|
kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
|
|
|
|
|
|
|
|
|
|
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
|
|
|
|
|
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
beam_candidates.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
@ -3443,12 +3553,65 @@ int whisper_full(
|
|
|
|
|
} else {
|
|
|
|
|
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
|
|
|
|
} break;
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
{
|
|
|
|
|
// TODO: ..
|
|
|
|
|
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
|
|
|
|
|
|
|
|
|
for (const auto & token : tokens_new) {
|
|
|
|
|
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
|
|
|
|
|
beam_candidates.back().sequence.tokens.push_back(token);
|
|
|
|
|
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
|
|
|
|
|
|
|
|
|
|
//WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
|
|
|
|
|
}
|
|
|
|
|
} break;
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
|
|
|
std::sort(
|
|
|
|
|
beam_candidates.begin(),
|
|
|
|
|
beam_candidates.end(),
|
|
|
|
|
[](const beam_candidate & a, const beam_candidate & b) {
|
|
|
|
|
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
int cur_c = 0;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto & cur = beam_candidates[cur_c++];
|
|
|
|
|
|
|
|
|
|
while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
|
|
|
|
|
++cur_c;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
decoder.sequence = cur.sequence;
|
|
|
|
|
decoder.seek_delta = cur.seek_delta;
|
|
|
|
|
decoder.has_ts = cur.has_ts;
|
|
|
|
|
|
|
|
|
|
memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
|
|
|
|
|
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
|
|
|
|
|
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
|
|
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto & has_ts = decoder.has_ts;
|
|
|
|
|
auto & failed = decoder.failed;
|
|
|
|
@ -3659,6 +3822,7 @@ int whisper_full(
|
|
|
|
|
} else {
|
|
|
|
|
text += whisper_token_to_str(ctx, tokens_cur[i].id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
|
|
|
|
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
|
|
|
|
if (!text.empty()) {
|
|
|
|
@ -4077,6 +4241,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
p1--;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
|
|
|
|
|
|
|
|
|
|
if (p1 > p0) {
|
|
|
|
|
double psum = 0.0;
|
|
|
|
|
for (int j = p0; j <= p1; j++) {
|
|
|
|
@ -4205,11 +4371,11 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
// debug info
|
|
|
|
|
//for (int j = 0; j < n; ++j) {
|
|
|
|
|
// const auto & token = tokens[j];
|
|
|
|
|
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
|
|
|
|
|
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
|
|
|
|
|
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
|
|
|
|
|
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
|
|
|
|
|
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
|
|
|
|
|
|
|
|
|
|
// if (tokens[j].id >= whisper_token_eot(ctx)) {
|
|
|
|
|
// if (tokens[j].id >= whisper_token_eot(&ctx)) {
|
|
|
|
|
// continue;
|
|
|
|
|
// }
|
|
|
|
|
//}
|
|
|
|
|