ref #52 : improve greedy sampling strategy

Force timestamp token to be sampled if the probability sum over all
timestamp tokens is above the probability of any other token
pull/67/head
Georgi Gerganov 2 years ago
parent 632660abb9
commit 7eeef0358a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1784,7 +1784,7 @@ bool whisper_decode(
// the most basic sampling scheme - select the top token
whisper_vocab::id whisper_sample_best(
const whisper_vocab & vocab,
const float * probs, bool need_timestamp) {
const float * probs) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
probs_id.push_back(std::make_pair(probs[i], i));
}
const int top_k = 4;
double sum_ts = 0.0;
double max_tx = 0.0;
for (int i = 0; i < vocab.token_beg; i++) {
max_tx = std::max(max_tx, probs_id[i].first);
}
for (int i = vocab.token_beg; i < n_logits; i++) {
sum_ts += probs_id[i].first;
}
// if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
// timestamp token
if (sum_ts > max_tx) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY;
}
}
// find the top K tokens
const int top_k = 4;
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
if (need_timestamp) {
// at the end of the 30-second audio segment, we start giving preference to time tokens
for (int i = 0; i < top_k; i++) {
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
return probs_id[i].second;
}
}
}
int res = 0;
while ((probs_id[res].second == vocab.token_sot ||
probs_id[res].second == vocab.token_solm ||
@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return 0;
}
whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
whisper_token whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2437,7 +2448,7 @@ int whisper_full(
whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx);
id = whisper_sample_best(ctx, result_len == 0);
id = whisper_sample_best(ctx);
if (i > 0) {
tid = whisper_sample_timestamp(ctx);
}

@ -120,7 +120,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found

Loading…
Cancel
Save