|
|
|
@ -433,7 +433,9 @@ struct whisper_context {
|
|
|
|
|
std::vector<whisper_segment> result_all;
|
|
|
|
|
std::vector<whisper_token> prompt_past;
|
|
|
|
|
|
|
|
|
|
std::vector<float> work_logits; // used to avoid allocations
|
|
|
|
|
// used to avoid allocations
|
|
|
|
|
std::vector<float> work_logits;
|
|
|
|
|
std::vector<float> work_logprobs;
|
|
|
|
|
|
|
|
|
|
// [EXPERIMENTAL] token-level timestamps data
|
|
|
|
|
int64_t t_beg;
|
|
|
|
@ -2700,6 +2702,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0,
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.dummy =*/ 0,
|
|
|
|
|
},
|
|
|
|
@ -2750,6 +2754,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0,
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.dummy =*/ 0,
|
|
|
|
|
},
|
|
|
|
@ -2849,17 +2855,120 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
|
|
|
|
|
// extract the logits for the last token
|
|
|
|
|
// the logits are mutated by the filters below, so work on a copy instead of using the ctx->logits buffer directly
|
|
|
|
|
auto & logits = ctx->work_logits;
|
|
|
|
|
auto & logits = ctx->work_logits;
|
|
|
|
|
auto & logprobs = ctx->work_logprobs;
|
|
|
|
|
{
|
|
|
|
|
logits.resize(n_logits);
|
|
|
|
|
memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
|
|
|
|
|
|
// will be populated a bit later
|
|
|
|
|
logprobs.resize(n_logits);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// apply logit filters here
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
|
|
|
|
|
// TODO: apply logit filters here
|
|
|
|
|
{
|
|
|
|
|
// suppress blank
|
|
|
|
|
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
|
|
|
|
|
if (params.suppress_blank) {
|
|
|
|
|
if (is_initial) {
|
|
|
|
|
logits[vocab.token_eot] = -INFINITY;
|
|
|
|
|
logits[vocab.token_to_id.at(" ")] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// suppress <|notimestamps|> token
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
|
|
|
|
|
logits[vocab.token_not] = -INFINITY;
|
|
|
|
|
|
|
|
|
|
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
|
|
|
|
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
|
|
|
|
|
{
|
|
|
|
|
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
|
|
|
|
|
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
|
|
|
|
|
|
|
|
|
|
if (last_was_timestamp) {
|
|
|
|
|
if (penultimate_was_timestamp) {
|
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++ i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < vocab.token_eot; ++ i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// the initial timestamp cannot be larger than max_initial_timestamp
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
|
|
|
|
|
if (is_initial && params.max_initial_timestamp > 0.0f) {
|
|
|
|
|
const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
|
|
|
|
|
const int tid0 = std::round(params.max_initial_timestamp/precision);
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// populate the logprobs array (log_softmax)
|
|
|
|
|
{
|
|
|
|
|
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
|
|
|
|
float logsumexp = 0.0f;
|
|
|
|
|
for (int i = 0; i < n_logits; ++ i) {
|
|
|
|
|
logsumexp += expf(logits[i] - logit_max);
|
|
|
|
|
}
|
|
|
|
|
logsumexp = logf(logsumexp) + logit_max;
|
|
|
|
|
for (int i = 0; i < n_logits; ++ i) {
|
|
|
|
|
logprobs[i] = logits[i] - logsumexp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if the total probability over all timestamp tokens exceeds that of the most likely non-timestamp token, mask every non-timestamp token so that a timestamp is sampled
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
|
|
|
|
|
{
|
|
|
|
|
// logsumexp over timestamps
|
|
|
|
|
float timestamp_logprob = -INFINITY;
|
|
|
|
|
{
|
|
|
|
|
float logsumexp = 0.0f;
|
|
|
|
|
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
|
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++ i) {
|
|
|
|
|
logsumexp += expf(logprobs[i] - logprob_max);
|
|
|
|
|
}
|
|
|
|
|
logsumexp = logf(logsumexp) + logprob_max;
|
|
|
|
|
timestamp_logprob = logsumexp;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
|
|
|
|
|
|
|
|
|
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
|
|
|
for (int i = 0; i < vocab.token_beg; ++ i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// print first 100 logits - token string : logit
|
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
|
|
|
const auto token = vocab.id_to_token.at(i);
|
|
|
|
|
const auto logit = logits[i];
|
|
|
|
|
printf("%s : %f\n", token.c_str(), logit);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// "And", "and", " And", " and"
|
|
|
|
|
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
|
|
|
|
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
|
|
|
|
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
|
|
|
|
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
|
|
|
|
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
|
|
|
|
|
|
|
|
|
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
|
|
|
|
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
|
|
|
|
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
|
|
|
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
|
|
|
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
|