|
|
@ -1784,7 +1784,7 @@ bool whisper_decode(
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
whisper_vocab::id whisper_sample_best(
|
|
|
|
whisper_vocab::id whisper_sample_best(
|
|
|
|
const whisper_vocab & vocab,
|
|
|
|
const whisper_vocab & vocab,
|
|
|
|
const float * probs, bool need_timestamp) {
|
|
|
|
const float * probs) {
|
|
|
|
int n_logits = vocab.id_to_token.size();
|
|
|
|
int n_logits = vocab.id_to_token.size();
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
|
|
|
|
probs_id.push_back(std::make_pair(probs[i], i));
|
|
|
|
probs_id.push_back(std::make_pair(probs[i], i));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const int top_k = 4;
|
|
|
|
double sum_ts = 0.0;
|
|
|
|
|
|
|
|
double max_tx = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < vocab.token_beg; i++) {
|
|
|
|
|
|
|
|
max_tx = std::max(max_tx, probs_id[i].first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; i++) {
|
|
|
|
|
|
|
|
sum_ts += probs_id[i].first;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
|
|
|
|
|
|
|
|
// timestamp token
|
|
|
|
|
|
|
|
if (sum_ts > max_tx) {
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
|
|
|
|
|
|
|
|
for (int i = 0; i < vocab.token_beg; i++) {
|
|
|
|
|
|
|
|
probs_id[i].first = -INFINITY;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// find the top K tokens
|
|
|
|
// find the top K tokens
|
|
|
|
|
|
|
|
const int top_k = 4;
|
|
|
|
|
|
|
|
|
|
|
|
std::partial_sort(
|
|
|
|
std::partial_sort(
|
|
|
|
probs_id.begin(),
|
|
|
|
probs_id.begin(),
|
|
|
|
probs_id.begin() + top_k, probs_id.end(),
|
|
|
|
probs_id.begin() + top_k, probs_id.end(),
|
|
|
@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
|
|
|
|
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
|
|
|
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
if (need_timestamp) {
|
|
|
|
|
|
|
|
// at the end of the 30-second audio segment, we start giving preference to time tokens
|
|
|
|
|
|
|
|
for (int i = 0; i < top_k; i++) {
|
|
|
|
|
|
|
|
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
|
|
|
|
|
|
|
|
return probs_id[i].second;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int res = 0;
|
|
|
|
int res = 0;
|
|
|
|
while ((probs_id[res].second == vocab.token_sot ||
|
|
|
|
while ((probs_id[res].second == vocab.token_sot ||
|
|
|
|
probs_id[res].second == vocab.token_solm ||
|
|
|
|
probs_id[res].second == vocab.token_solm ||
|
|
|
@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
|
|
return 0;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
|
|
|
|
whisper_token whisper_sample_best(struct whisper_context * ctx) {
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: simplify
|
|
|
|
// TODO: simplify
|
|
|
|
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
|
|
|
|
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
|
|
|
|
|
|
|
|
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
|
|
|
|
|
|
|
@ -2437,7 +2448,7 @@ int whisper_full(
|
|
|
|
whisper_token id = 0;
|
|
|
|
whisper_token id = 0;
|
|
|
|
whisper_token tid = whisper_token_beg(ctx);
|
|
|
|
whisper_token tid = whisper_token_beg(ctx);
|
|
|
|
|
|
|
|
|
|
|
|
id = whisper_sample_best(ctx, result_len == 0);
|
|
|
|
id = whisper_sample_best(ctx);
|
|
|
|
if (i > 0) {
|
|
|
|
if (i > 0) {
|
|
|
|
tid = whisper_sample_timestamp(ctx);
|
|
|
|
tid = whisper_sample_timestamp(ctx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|