whisper : add beam-search support

pull/291/head
Georgi Gerganov 3 years ago
parent 5e97f80fc5
commit 5548a1986f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -652,6 +652,9 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();

@ -433,7 +433,8 @@ struct whisper_sequence {
// the accumulated transcription in the current interation (used to truncate the tokens array) // the accumulated transcription in the current interation (used to truncate the tokens array)
int result_len; int result_len;
double sum_logprobs; // the sum of the log probabilities of the tokens double sum_logprobs_all; // the sum of the log probabilities of the tokens
double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
double avg_logprobs; // the average log probability of the tokens double avg_logprobs; // the average log probability of the tokens
double entropy; // the entropy of the tokens double entropy; // the entropy of the tokens
double score; // likelihood rank score double score; // likelihood rank score
@ -1650,7 +1651,7 @@ static bool whisper_decode(
const int N = n_tokens; const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size(); params.mem_size = wctx.buf_compute.size();
@ -3100,6 +3101,74 @@ static whisper_token_data whisper_sample_token(
return result; return result;
} }
static std::vector<whisper_token_data> whisper_sample_token_topk(
whisper_context & ctx,
const whisper_decoder & decoder,
int k) {
const auto & vocab = ctx.vocab;
const auto & probs = decoder.probs;
const auto & logits = decoder.logits;
const auto & logprobs = decoder.logprobs;
const int n_logits = vocab.n_vocab;
auto & logits_id = ctx.logits_id;
logits_id.clear();
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back({ logits[i], i });
}
std::partial_sort(
logits_id.begin(),
logits_id.begin() + k, logits_id.end(),
[](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
return a.first > b.first;
});
std::vector<whisper_token_data> result(k);
{
double sum_ts = 0.0;
double max_ts = 0.0;
for (int i = vocab.token_beg; i < n_logits; i++) {
if (probs[i] == -INFINITY) {
continue;
}
sum_ts += probs[i];
if (max_ts < probs[i]) {
max_ts = probs[i];
result[0].tid = i;
}
}
result[0].pt = max_ts/(sum_ts + 1e-10);
result[0].ptsum = sum_ts;
}
for (int i = 0; i < k; ++i) {
result[i].id = logits_id[i].second;
result[i].p = probs[result[i].id];
result[i].plog = logprobs[result[i].id];
result[i].tid = result[0].tid;
result[i].pt = result[0].pt;
result[i].ptsum = result[0].ptsum;
result[i].t0 = -1;
result[i].t1 = -1;
result[i].vlen = 0.0f;
if (result[i].id >= vocab.token_beg) {
result[i].tid = result[i].id;
result[i].pt = result[i].p;
}
}
return result;
}
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
static void whisper_sequence_score( static void whisper_sequence_score(
const struct whisper_full_params & params, const struct whisper_full_params & params,
@ -3120,7 +3189,7 @@ static void whisper_sequence_score(
double penalty = sequence.result_len; double penalty = sequence.result_len;
if (params.length_penalty > 0.0f) { if (params.length_penalty > 0.0f) {
penalty = pow((5.0 + penalty) / 6.0, params.length_penalty); penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
} }
sequence.score = result/penalty; sequence.score = result/penalty;
@ -3141,7 +3210,8 @@ static void whisper_sequence_score(
for (const auto & kv : token_counts) { for (const auto & kv : token_counts) {
const auto p = kv.second/(double)cnt; const auto p = kv.second/(double)cnt;
entropy -= p*log(p); entropy -= p*log(p);
//printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
//WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
} }
sequence.entropy = entropy; sequence.entropy = entropy;
@ -3293,6 +3363,25 @@ int whisper_full(
std::vector<whisper_token> prompt; std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx)); prompt.reserve(whisper_n_text_ctx(ctx));
// beam-search helpers
struct kv_buf {
std::vector<uint8_t> k;
std::vector<uint8_t> v;
};
std::vector<kv_buf> kv_bufs;
struct beam_candidate {
int decoder_idx;
int seek_delta;
bool has_ts;
whisper_sequence sequence;
};
std::vector<beam_candidate> beam_candidates;
// main loop // main loop
while (true) { while (true) {
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
@ -3361,9 +3450,10 @@ int whisper_full(
decoder.sequence.tokens.clear(); decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0; decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs_all = 0.0;
decoder.sequence.sum_logprobs = -INFINITY; decoder.sequence.sum_logprobs = -INFINITY;
decoder.sequence.avg_logprobs = -INFINITY; decoder.sequence.avg_logprobs = -INFINITY;
decoder.sequence.entropy = 0.0f; decoder.sequence.entropy = 0.0;
decoder.sequence.score = -INFINITY; decoder.sequence.score = -INFINITY;
decoder.n_past = 0; decoder.n_past = 0;
@ -3412,7 +3502,8 @@ int whisper_full(
for (int j = 1; j < n_decoders_cur; ++j) { for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
decoder.n_past += prompt.size(); decoder.n_past += prompt.size();
@ -3428,6 +3519,25 @@ int whisper_full(
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
kv_bufs.resize(n_decoders_cur);
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
}
beam_candidates.clear();
}
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -3443,12 +3553,65 @@ int whisper_full(
} else { } else {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
} }
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
} break; } break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{ {
// TODO: .. const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
beam_candidates.back().sequence.tokens.push_back(token);
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
//WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
}
} break; } break;
}; };
}
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
std::sort(
beam_candidates.begin(),
beam_candidates.end(),
[](const beam_candidate & a, const beam_candidate & b) {
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
int cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
auto & cur = beam_candidates[cur_c++];
while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
++cur_c;
}
decoder.sequence = cur.sequence;
decoder.seek_delta = cur.seek_delta;
decoder.has_ts = cur.has_ts;
memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
}
}
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
auto & has_ts = decoder.has_ts; auto & has_ts = decoder.has_ts;
auto & failed = decoder.failed; auto & failed = decoder.failed;
@ -3659,6 +3822,7 @@ int whisper_full(
} else { } else {
text += whisper_token_to_str(ctx, tokens_cur[i].id); text += whisper_token_to_str(ctx, tokens_cur[i].id);
} }
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) { if (!text.empty()) {
@ -4077,6 +4241,8 @@ static void whisper_exp_compute_token_level_timestamps(
p1--; p1--;
} }
//printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
if (p1 > p0) { if (p1 > p0) {
double psum = 0.0; double psum = 0.0;
for (int j = p0; j <= p1; j++) { for (int j = p0; j <= p1; j++) {
@ -4205,11 +4371,11 @@ static void whisper_exp_compute_token_level_timestamps(
// debug info // debug info
//for (int j = 0; j < n; ++j) { //for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j]; // const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) { // if (tokens[j].id >= whisper_token_eot(&ctx)) {
// continue; // continue;
// } // }
//} //}

@ -290,7 +290,7 @@ extern "C" {
struct { struct {
int beam_size; int beam_size;
float patience; float patience; // TODO: not implemented
} beam_search; } beam_search;
whisper_new_segment_callback new_segment_callback; whisper_new_segment_callback new_segment_callback;

Loading…
Cancel
Save