whisper : fix timestamp sampling

pull/12/head
Georgi Gerganov 2 years ago
parent 7094be1f37
commit 624e4f5313
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

whisper.cpp

@@ -1784,7 +1784,7 @@ bool whisper_decode(
 // the most basic sampling scheme - select the top token
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs, bool need_timestamp) {
+        const float * probs) {
     int n_logits = vocab.id_to_token.size();

     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }

-    const int top_k = 4;
+    double sum_ts = 0.0;
+    double max_tx = 0.0;
+
+    for (int i = 0; i < vocab.token_beg; i++) {
+        max_tx = std::max(max_tx, probs_id[i].first);
+    }
+
+    for (int i = vocab.token_beg; i < n_logits; i++) {
+        sum_ts += probs_id[i].first;
+    }
+
+    // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
+    // timestamp token
+    if (sum_ts > max_tx) {
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+        for (int i = 0; i < vocab.token_beg; i++) {
+            probs_id[i].first = -INFINITY;
+        }
+    }

     // find the top K tokens
+    const int top_k = 4;
     std::partial_sort(
             probs_id.begin(),
             probs_id.begin() + top_k, probs_id.end(),
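
Outside the diff, a minimal standalone sketch of the rule this hunk introduces may help: at each decoding step, if the combined probability mass of all timestamp tokens exceeds the probability of the single best text token, the text tokens are masked out so that a timestamp token gets sampled (mirroring the referenced OpenAI decoding rule). The helper name prefer_timestamp and the toy probabilities are illustrative only, not part of whisper.cpp.

// sketch: decide whether sampling should be restricted to timestamp tokens
#include <algorithm>
#include <cstdio>
#include <vector>

// probs     - softmax probabilities over the whole vocabulary
// token_beg - id of the first timestamp token (text tokens come before it)
static bool prefer_timestamp(const std::vector<float> & probs, int token_beg) {
    double max_tx = 0.0; // highest probability among text tokens
    double sum_ts = 0.0; // total probability mass of timestamp tokens

    for (int i = 0; i < token_beg; ++i) {
        max_tx = std::max(max_tx, (double) probs[i]);
    }
    for (int i = token_beg; i < (int) probs.size(); ++i) {
        sum_ts += probs[i];
    }

    // same comparison as in the hunk above
    return sum_ts > max_tx;
}

int main() {
    // toy vocabulary: 4 text tokens followed by 4 timestamp tokens
    const std::vector<float> probs = { 0.10f, 0.20f, 0.05f, 0.05f, 0.25f, 0.15f, 0.10f, 0.10f };

    std::printf("restrict sampling to timestamp tokens: %s\n", prefer_timestamp(probs, 4) ? "yes" : "no");
    return 0;
}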
@@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
     //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
     //}

-    if (need_timestamp) {
-        // at the end of the 30-second audio segment, we start giving preference to time tokens
-        for (int i = 0; i < top_k; i++) {
-            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
-                return probs_id[i].second;
-            }
-        }
-    }
-
     int res = 0;
     while ((probs_id[res].second == vocab.token_sot ||
             probs_id[res].second == vocab.token_solm ||
@@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     return 0;
 }

-whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
+whisper_token whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();

     // TODO: simplify
-    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
+    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -2437,7 +2448,7 @@ int whisper_full(
         whisper_token id  = 0;
         whisper_token tid = whisper_token_beg(ctx);

-        id = whisper_sample_best(ctx, result_len == 0);
+        id = whisper_sample_best(ctx);
         if (i > 0) {
             tid = whisper_sample_timestamp(ctx);
         }

whisper.h

@@ -120,7 +120,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
+    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);

     // Return the id of the specified language, returns -1 if not found
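
As a hedged illustration of the simplified interface (assuming the whisper.h API at this commit), a caller that previously passed a need_timestamp flag now just calls whisper_sample_best(ctx); the wrapper function below is illustrative, not code from the repository.

// sketch of a caller using the simplified sampling API
#include "whisper.h"

static whisper_token sample_next(struct whisper_context * ctx, int i) {
    // the text-vs-timestamp decision now happens inside whisper_sample_best()
    whisper_token id  = whisper_sample_best(ctx);
    whisper_token tid = whisper_token_beg(ctx);

    if (i > 0) {
        // the most probable timestamp token can still be queried separately,
        // e.g. for tracking segment boundaries as whisper_full() does
        tid = whisper_sample_timestamp(ctx);
    }

    (void) tid; // kept only to mirror the call site in whisper_full()
    return id;
}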

ggml.c

@@ -75,6 +75,9 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
 #include <immintrin.h>

+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
 static inline float fp32_from_bits(uint32_t w) {
     union {
         uint32_t as_bits;
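
The two comment lines added here point at the FP16 conversion helpers that follow in ggml.c. Below is a standalone sketch of the bit-level reinterpretation they rely on, using memcpy as a portable equivalent of the union shown above; it is an illustration, not the ggml.c implementation.

// sketch: reinterpret an IEEE-754 bit pattern as a float and back
#include <cstdint>
#include <cstdio>
#include <cstring>

static inline float fp32_from_bits(uint32_t w) {
    float f;
    std::memcpy(&f, &w, sizeof f); // safe bit-cast; ggml.c uses a union for the same purpose
    return f;
}

static inline uint32_t fp32_to_bits(float f) {
    uint32_t w;
    std::memcpy(&w, &f, sizeof w);
    return w;
}

int main() {
    // 0x3f800000 is the single-precision encoding of 1.0f
    std::printf("%f\n", fp32_from_bits(0x3f800000u));
    std::printf("0x%08x\n", fp32_to_bits(1.0f));
    return 0;
}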
