whisper : wip decoding parameters + strategies (part 4)

pull/291/head
Georgi Gerganov 3 years ago
parent 116dd67a15
commit bd6e70b512

@@ -472,13 +472,12 @@ struct whisper_context {
std::vector<uint8_t> buf_compute_layer;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> probs;
std::vector<float> logits;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+ std::vector<std::pair<double, whisper_vocab::id>> logits_id;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
@@ -575,6 +574,10 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
fprintf(stderr, "%s: loading model\n", __func__);
+ const int64_t t_start_us = ggml_time_us();
+ wctx.t_start_us = t_start_us;
auto & model = wctx.model;
auto & vocab = wctx.vocab;
@@ -770,9 +773,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
- wctx.probs_id.reserve(n_vocab);
+ wctx.logits_id.reserve(n_vocab);
// TAGS: WHISPER_DECODER_INIT
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
@@ -1178,6 +1180,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
}
+ wctx.t_load_us = ggml_time_us() - t_start_us;
return true;
}
@@ -1192,8 +1196,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
//
static bool whisper_encode(
whisper_context & wctx,
- const int n_threads,
- const int mel_offset) {
+ const int mel_offset,
+ const int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams;
@@ -1585,6 +1591,8 @@ static bool whisper_encode(
ggml_free(ctx0);
+ wctx.t_encode_us += ggml_time_us() - t_start_us;
return true;
}
@@ -1601,10 +1609,12 @@ static bool whisper_encode(
static bool whisper_decode(
whisper_context & wctx,
whisper_decoder & decoder,
- const int n_threads,
const whisper_token * tokens,
const int n_tokens,
- const int n_past) {
+ const int n_past,
+ const int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
const auto & hparams = model.hparams;
@@ -1613,7 +1623,6 @@ static bool whisper_decode(
WHISPER_ASSERT(!!kv_self.ctx);
auto & logits_out = wctx.logits;
- auto & probs_out = wctx.probs;
const int n_vocab = hparams.n_vocab;
@@ -1625,6 +1634,8 @@ static bool whisper_decode(
const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+ //fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx);
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data();
@@ -1933,25 +1944,18 @@ static bool whisper_decode(
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
- // logits -> probs
- cur = ggml_dup(ctx0, logits);
- cur = ggml_soft_max(ctx0, cur); // in-place
// run the computation
{
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
- ggml_build_forward_expand(&gf, cur);
+ ggml_build_forward_expand(&gf, logits);
ggml_graph_compute (ctx0, &gf);
}
logits_out.resize(N*n_vocab);
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
- probs_out.resize(N*n_vocab);
- memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
if (N > 1) {
//const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
//printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
@@ -1960,6 +1964,8 @@ static bool whisper_decode(
ggml_free(ctx0);
+ wctx.t_decode_us += ggml_time_us() - t_start_us;
return true;
}
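Note: with this hunk the decoder graph no longer computes softmax; only raw logits are copied out, and probabilities are derived later on the CPU in whisper_process_logits, after the logits have been filtered. A minimal sketch of the CPU-side equivalent (plain C++, independent of ggml; an illustrative restatement, not the literal whisper.cpp code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// softmax over one row of logits; subtracting the max first keeps
// expf() from overflowing and does not change the result
static std::vector<float> softmax(const std::vector<float> & logits) {
    const float max = *std::max_element(logits.begin(), logits.end());

    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = expf(logits[i] - max);
        sum += probs[i];
    }
    for (auto & p : probs) {
        p /= sum;
    }
    return probs;
}
```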
@@ -2062,6 +2068,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram(
+ whisper_context & wctx,
const float * samples,
const int n_samples,
const int /*sample_rate*/,
@@ -2072,6 +2079,7 @@ static bool log_mel_spectrogram(
const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) {
+ const int64_t t_start_us = ggml_time_us();
// Hanning window
std::vector<float> hann;
@@ -2180,6 +2188,8 @@ static bool log_mel_spectrogram(
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
+ wctx.t_mel_us += ggml_time_us() - t_start_us;
return true;
}
@@ -2324,10 +2334,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
whisper_context * ctx = new whisper_context;
- const int64_t t_start_us = ggml_time_us();
- ctx->t_start_us = t_start_us;
if (!whisper_model_load(loader, *ctx)) {
loader->close(loader->context);
fprintf(stderr, "%s: failed to load model\n", __func__);
@@ -2335,8 +2341,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
return nullptr;
}
- ctx->t_load_us = ggml_time_us() - t_start_us;
loader->close(loader->context);
return ctx;
@@ -2363,29 +2367,21 @@ void whisper_free(struct whisper_context * ctx) {
}
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
- if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
- ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
- if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
- ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0;
}
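Note: the timing now lives inside log_mel_spectrogram itself (hence the new wctx parameter), so both wrappers reduce to a bare call. Public usage is unchanged; a sketch, where the PCM buffer is assumed to hold mono float samples at WHISPER_SAMPLE_RATE (16 kHz):

```cpp
#include <cstdio>
#include <vector>
#include "whisper.h"

// sketch: compute the mel spectrogram up front so that whisper_encode /
// whisper_full can consume it from the context later
static bool compute_mel(struct whisper_context * ctx, const std::vector<float> & pcmf32) {
    if (whisper_pcm_to_mel(ctx, pcmf32.data(), (int) pcmf32.size(), 4) != 0) {
        fprintf(stderr, "failed to compute mel spectrogram\n");
        return false;
    }
    return true;
}
```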
@@ -2409,31 +2405,23 @@ int whisper_set_mel(
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
- if (!whisper_encode(*ctx, n_threads, offset)) {
+ if (!whisper_encode(*ctx, offset, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return -1;
}
- ctx->t_encode_us += ggml_time_us() - t_start_us;
return 0;
}
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
// TODO: add selected_decoder_id to context
const int selected_decoder_id = 0;
- if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
+ if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return 1;
}
- ctx->t_decode_us += ggml_time_us() - t_start_us;
return 0;
}
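Note: only the internal whisper_encode/whisper_decode moved n_threads to the last position; the public wrappers keep their signatures. A sketch of calling them (assumes a loaded context with a mel already set; the thread count of 4 is arbitrary):

```cpp
#include <cstdio>
#include "whisper.h"

// sketch: run the encoder from mel offset 0, then feed a single SOT token
static bool encode_then_decode(struct whisper_context * ctx) {
    if (whisper_encode(ctx, 0, 4) != 0) {
        fprintf(stderr, "encode failed\n");
        return false;
    }

    const whisper_token sot = whisper_token_sot(ctx);
    if (whisper_decode(ctx, &sot, 1, 0, 4) != 0) {
        fprintf(stderr, "decode failed\n");
        return false;
    }

    return true;
}
```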
@@ -2517,36 +2505,39 @@ int whisper_lang_auto_detect(
return -7;
}
- auto & probs_id = ctx->probs_id;
- probs_id.clear();
+ auto & logits_id = ctx->logits_id;
+ logits_id.clear();
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
- probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+ logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
}
// sort descending
{
- using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;
- std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+ using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+ std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// softmax
{
- float sum = 0;
- for (const auto & kv : probs_id) {
- sum += exp(kv.first);
+ const auto max = logits_id[0].first;
+ double sum = 0.0f;
+ for (auto & kv : logits_id) {
+ kv.first = exp(kv.first - max);
+ sum += kv.first;
}
- for (auto & kv : probs_id) {
- kv.first = exp(kv.first) / sum;
+ for (auto & kv : logits_id) {
+ kv.first /= sum;
}
}
{
- for (const auto & prob : probs_id) {
+ for (const auto & prob : logits_id) {
if (lang_probs) {
lang_probs[prob.second] = prob.first;
}
@@ -2555,7 +2546,7 @@ int whisper_lang_auto_detect(
}
}
- return probs_id[0].second;
+ return logits_id[0].second;
}
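Note: besides the rename, this hunk fixes a numerical hazard: the old code exponentiated raw values, and expf overflows to inf once its argument exceeds logf(FLT_MAX), roughly 88.7. Subtracting the maximum first is mathematically a no-op, since softmax(x) == softmax(x - c) for any constant c, but it keeps every argument at or below zero. A standalone illustration with made-up logits:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // expf overflows to +inf once its argument exceeds ~88.7
    printf("expf(100)  = %f\n", expf(100.0f)); // inf

    // two large logits; the naive softmax divides inf by inf -> nan
    const float l0 = 100.0f, l1 = 98.0f;
    printf("naive  p0  = %f\n", expf(l0)/(expf(l0) + expf(l1))); // nan

    // stable form: subtract the max first; identical in exact arithmetic
    const float m = l0 > l1 ? l0 : l1;
    printf("stable p0  = %f\n", expf(l0 - m)/(expf(l0 - m) + expf(l1 - m))); // ~0.881

    return 0;
}
```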
int whisper_n_len(struct whisper_context * ctx) {
@@ -2582,10 +2573,6 @@ float * whisper_get_logits(struct whisper_context * ctx) {
return ctx->logits.data();
}
- float * whisper_get_probs(struct whisper_context * ctx) {
- return ctx->probs.data();
- }
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
return ctx->vocab.id_to_token.at(token).c_str();
}
@@ -2878,13 +2865,11 @@ static void whisper_process_logits(
auto & logits = decoder.logits;
auto & logprobs = decoder.logprobs;
{
- probs.resize(n_logits);
- memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float));
logits.resize(n_logits);
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
+ // will be populated a bit later
+ probs.resize(n_logits);
logprobs.resize(n_logits);
}
@@ -2904,6 +2889,10 @@ static void whisper_process_logits(
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
+ // suppress sot and solm tokens
+ logits[vocab.token_sot] = -INFINITY;
+ logits[vocab.token_solm] = -INFINITY;
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{
@@ -2914,11 +2903,11 @@ static void whisper_process_logits(
if (last_was_timestamp) {
if (penultimate_was_timestamp) {
- for (int i = vocab.token_beg; i < n_logits; ++ i) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
logits[i] = -INFINITY;
}
} else {
- for (int i = 0; i < vocab.token_eot; ++ i) {
+ for (int i = 0; i < vocab.token_eot; ++i) {
logits[i] = -INFINITY;
}
}
@@ -2931,7 +2920,7 @@ static void whisper_process_logits(
const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
const int tid0 = std::round(params.max_initial_timestamp/precision);
- for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
+ for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
logits[i] = -INFINITY;
}
}
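Note: the rule encoded above is that timestamp tokens must appear in begin/end pairs: after a lone timestamp only another timestamp (or EOT) may follow, and after a completed pair the next token must be text. A simplified standalone sketch of the same masking (it ignores the prompt-boundary handling the real code performs when deciding whether a penultimate token exists):

```cpp
#include <cmath>
#include <vector>

// sketch: mask a logits row so that timestamp tokens can only appear in pairs;
// token_beg is the first timestamp token id, token_eot the end-of-text id
static void apply_timestamp_pairing(
        std::vector<float> & logits,
        const std::vector<int> & tokens, // tokens sampled so far
        int token_beg,
        int token_eot) {
    const int n = (int) tokens.size();

    const bool last_was_timestamp        = n >= 1 && tokens[n - 1] >= token_beg;
    const bool penultimate_was_timestamp = n >= 2 && tokens[n - 2] >= token_beg;

    if (!last_was_timestamp) {
        return;
    }

    if (penultimate_was_timestamp) {
        // pair complete: only text (or EOT) may follow
        for (int i = token_beg; i < (int) logits.size(); ++i) {
            logits[i] = -INFINITY;
        }
    } else {
        // lone timestamp: force a closing timestamp (or EOT)
        for (int i = 0; i < token_eot; ++i) {
            logits[i] = -INFINITY;
        }
    }
}
```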
@@ -2940,12 +2929,19 @@ static void whisper_process_logits(
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
- for (int i = 0; i < n_logits; ++ i) {
+ for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logsumexp += expf(logits[i] - logit_max);
}
}
logsumexp = logf(logsumexp) + logit_max;
- for (int i = 0; i < n_logits; ++ i) {
+ for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logprobs[i] = logits[i] - logsumexp;
} else {
logprobs[i] = -INFINITY;
}
}
}
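Note: this block is a standard numerically stable log-softmax: logprobs[i] = logits[i] - logsumexp(logits), where logsumexp is computed as logf(sum of expf(logits[j] - max)) + max so no exponential can overflow. Exponentiating the result later recovers the probabilities, which is what the new "compute probs" block relies on. A compact restatement (a sketch, not the literal whisper.cpp code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// sketch: stable log-softmax via the logsumexp trick (input assumed non-empty)
static std::vector<float> log_softmax(const std::vector<float> & logits) {
    const float max = *std::max_element(logits.begin(), logits.end());

    float sum = 0.0f;
    for (const float l : logits) {
        sum += expf(l - max);
    }
    const float logsumexp = logf(sum) + max;

    std::vector<float> logprobs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = logits[i] - logsumexp;
    }
    return logprobs;
}
```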
@@ -2957,20 +2953,37 @@ static void whisper_process_logits(
{
float logsumexp = 0.0f;
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
- for (int i = vocab.token_beg; i < n_logits; ++ i) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
if (logprobs[i] > -INFINITY) {
logsumexp += expf(logprobs[i] - logprob_max);
}
- logsumexp = logf(logsumexp) + logprob_max;
- timestamp_logprob = logsumexp;
}
+ if (logsumexp > 0.0f) {
+ timestamp_logprob = logf(logsumexp) + logprob_max;
+ }
}
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
//fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
if (timestamp_logprob > max_text_token_logprob) {
- for (int i = 0; i < vocab.token_beg; ++ i) {
+ for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
}
}
}
+ // compute probs
+ {
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] == -INFINITY) {
+ probs[i] = 0.0f;
+ } else {
+ probs[i] = expf(logprobs[i]);
+ }
+ }
+ }
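Note: two separate fixes land here. The logsumexp > 0.0f guard avoids calling logf(0) when every timestamp token has been masked, leaving timestamp_logprob at -INFINITY; and probabilities are finally materialized as expf(logprobs), with hard-masked tokens pinned to exactly 0. The surrounding rule mirrors OpenAI's decoding.py: if the total probability mass on timestamp tokens exceeds that of the single most probable text token, only a timestamp may be sampled next. A standalone sketch of that decision (assumes token_beg >= 1):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// sketch: compare the aggregated timestamp probability mass against the
// best single text token, entirely in log space
static bool should_force_timestamp(const std::vector<float> & logprobs, int token_beg) {
    // max over the timestamp region, needed for a stable logsumexp
    float max_ts = -INFINITY;
    for (size_t i = token_beg; i < logprobs.size(); ++i) {
        max_ts = std::max(max_ts, logprobs[i]);
    }

    float sum = 0.0f;
    for (size_t i = token_beg; i < logprobs.size(); ++i) {
        if (logprobs[i] > -INFINITY) {
            sum += expf(logprobs[i] - max_ts);
        }
    }
    const float timestamp_logprob = sum > 0.0f ? logf(sum) + max_ts : -INFINITY;

    const float max_text_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + token_beg);

    return timestamp_logprob > max_text_logprob;
}
```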
@@ -2978,8 +2991,10 @@ static void whisper_process_logits(
// print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i);
+ const auto prob = probs[i];
const auto logit = logits[i];
- printf("%s : %f\n", token.c_str(), logit);
+ const auto logprob = logprobs[i];
+ printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
}
// "And", "and", " And", " and"
@@ -2994,26 +3009,13 @@ static void whisper_process_logits(
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
#endif
- //switch (params.strategy) {
- // case WHISPER_SAMPLING_GREEDY:
- // {
- // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
- // // TODO: implement
- // result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
- // } break;
- // case WHISPER_SAMPLING_BEAM_SEARCH:
- // {
- // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
- // // TODO: implement
- // } break;
- //}
- //sum_logprobs += logprobs[result.id];
- //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
- //return result;
+ printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+ printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+ printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+ printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+ printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}
// select the most probable token
@@ -3052,13 +3054,6 @@ static whisper_token_data whisper_sample_best(
}
for (int i = 0; i < n_logits; ++i) {
- // never sample these:
- if (i == vocab.token_sot ||
- i == vocab.token_solm ||
- i == vocab.token_not) {
- continue;
- }
if (result.p < probs[i]) {
result.id = i;
result.p = probs[i];
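Note: the deleted guard is not lost; sot, solm and not are now masked once in whisper_process_logits (see the suppression hunk above), so every sampling strategy sees the same filtered distribution and the greedy pick reduces to a plain argmax. A sketch:

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// sketch: greedy selection over probs that whisper_process_logits already masked
static int argmax_token(const std::vector<float> & probs) {
    return (int) std::distance(probs.begin(), std::max_element(probs.begin(), probs.end()));
}
```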
@@ -3334,18 +3329,23 @@ int whisper_full(
// print the prompt
//printf("\n\n");
- //for (int i = 0; i < prompt.size(); i++) {
+ //for (int i = 0; i < (int) prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
//}
//printf("\n\n");
- if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) {
+ if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+ whisper_process_logits(*ctx, ctx->decoders[0], params);
+ ctx->decoders[0].n_past += prompt.size();
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3357,9 +3357,14 @@ int whisper_full(
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
}
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
const int64_t t_start_sample_us = ggml_time_us();
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3407,7 +3412,8 @@ int whisper_full(
{
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
- printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+ printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+ __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
}
// end of segment
@@ -3468,6 +3474,8 @@ int whisper_full(
}
}
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3480,14 +3488,20 @@ int whisper_full(
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
- if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) {
+ if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+ whisper_process_logits(*ctx, decoder, params);
+ ++decoder.n_past;
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
}
}
@@ -3502,12 +3516,15 @@ int whisper_full(
continue;
}
- whisper_sequence_score(params, ctx->decoders[j].sequence);
+ decoder.sequence.tokens.resize(decoder.sequence.result_len);
+ whisper_sequence_score(params, decoder.sequence);
if (best_score < decoder.sequence.score) {
best_score = decoder.sequence.score;
best_decoder_id = j;
}
+ fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
}
}
@@ -3671,9 +3688,8 @@ int whisper_full_parallel(
ctx_p = *ctx;
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
- ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
+ ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
if (!kv_cache_reinit(ctx_p.kv_cross)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);

@@ -191,12 +191,6 @@ extern "C" {
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
- // Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode()
- // The probabilities for the last token are stored in the last row
- // Rows: n_tokens
- // Cols: n_vocab
- WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
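Note: with whisper_get_probs() removed from the public API, downstream code can rebuild the distribution from the logits that remain exposed. A hedged sketch using only surviving calls (whisper_n_vocab, whisper_get_logits); after a multi-token whisper_decode the row of interest is the last one, so the caller passes a pointer to the desired row:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>
#include "whisper.h"

// sketch: softmax over one n_vocab-sized row of the array returned by
// whisper_get_logits(); `row` must point at the start of that row
static std::vector<float> row_probs(struct whisper_context * ctx, const float * row) {
    const int n_vocab = whisper_n_vocab(ctx);

    const float max = *std::max_element(row, row + n_vocab);

    std::vector<float> probs(n_vocab);
    float sum = 0.0f;
    for (int i = 0; i < n_vocab; ++i) {
        probs[i] = expf(row[i] - max);
        sum += probs[i];
    }
    for (auto & p : probs) {
        p /= sum;
    }
    return probs;
}
```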
