whisper : apply logit filters and compute logprobs

pull/291/head
Georgi Gerganov 3 years ago
parent abe104b80b
commit 2d8d3724b8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

2
.gitignore vendored

@ -8,6 +8,7 @@ build/
build-em/ build-em/
build-debug/ build-debug/
build-release/ build-release/
build-static/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/
@ -18,6 +19,7 @@ build-sanitize-thread/
/bench /bench
sync.sh sync.sh
libwhisper.a
libwhisper.so libwhisper.so
compile_commands.json compile_commands.json

@ -433,7 +433,9 @@ struct whisper_context {
std::vector<whisper_segment> result_all; std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past; std::vector<whisper_token> prompt_past;
std::vector<float> work_logits; // used to avoid allocations // used to avoid allocations
std::vector<float> work_logits;
std::vector<float> work_logprobs;
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg;
@ -2700,6 +2702,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.max_initial_timestamp =*/ 1.0,
/*.greedy =*/ { /*.greedy =*/ {
/*.dummy =*/ 0, /*.dummy =*/ 0,
}, },
@ -2750,6 +2754,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.max_initial_timestamp =*/ 1.0,
/*.greedy =*/ { /*.greedy =*/ {
/*.dummy =*/ 0, /*.dummy =*/ 0,
}, },
@ -2850,16 +2856,119 @@ static struct whisper_token_data whisper_sample_next_token(
// extract the logits for the last token // extract the logits for the last token
// we will be mutating and therefore we don't want to use the ctx->logits buffer directly // we will be mutating and therefore we don't want to use the ctx->logits buffer directly
auto & logits = ctx->work_logits; auto & logits = ctx->work_logits;
auto & logprobs = ctx->work_logprobs;
{ {
logits.resize(n_logits); logits.resize(n_logits);
memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float)); memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
// will be populated a bit later
logprobs.resize(n_logits);
} }
// apply logit filters here
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
// TODO: apply logit filters here
{ {
// suppress blank
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
if (params.suppress_blank) {
if (is_initial) {
logits[vocab.token_eot] = -INFINITY;
logits[vocab.token_to_id.at(" ")] = -INFINITY;
}
}
// suppress <|notimestamps|> token
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
if (last_was_timestamp) {
if (penultimate_was_timestamp) {
for (int i = vocab.token_beg; i < n_logits; ++ i) {
logits[i] = -INFINITY;
}
} else {
for (int i = 0; i < vocab.token_eot; ++ i) {
logits[i] = -INFINITY;
}
}
}
}
// the initial timestamp cannot be larger than max_initial_timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial && params.max_initial_timestamp > 0.0f) {
const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
const int tid0 = std::round(params.max_initial_timestamp/precision);
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
logits[i] = -INFINITY;
}
}
// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++ i) {
logsumexp += expf(logits[i] - logit_max);
}
logsumexp = logf(logsumexp) + logit_max;
for (int i = 0; i < n_logits; ++ i) {
logprobs[i] = logits[i] - logsumexp;
}
}
// if sum of probability over timestamps is above any other token, sample timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
{
// logsumexp over timestamps
float timestamp_logprob = -INFINITY;
{
float logsumexp = 0.0f;
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
for (int i = vocab.token_beg; i < n_logits; ++ i) {
logsumexp += expf(logprobs[i] - logprob_max);
}
logsumexp = logf(logsumexp) + logprob_max;
timestamp_logprob = logsumexp;
} }
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
if (timestamp_logprob > max_text_token_logprob) {
for (int i = 0; i < vocab.token_beg; ++ i) {
logits[i] = -INFINITY;
}
}
}
}
// print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i);
const auto logit = logits[i];
printf("%s : %f\n", token.c_str(), logit);
}
// "And", "and", " And", " and"
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
switch (params.strategy) { switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY: case WHISPER_SAMPLING_GREEDY:
{ {

@ -274,8 +274,11 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto" // for auto-detection, set to nullptr, "" or "auto"
const char * language; const char * language;
// common decoding parameters:
bool suppress_blank; bool suppress_blank;
float max_initial_timestamp;
struct { struct {
int dummy; int dummy;
} greedy; } greedy;

Loading…
Cancel
Save