whisper : wip decoding parameters + strategies (part 4)

pull/291/head
Georgi Gerganov 3 years ago
parent 116dd67a15
commit bd6e70b512
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -472,13 +472,12 @@ struct whisper_context {
std::vector<uint8_t> buf_compute_layer; std::vector<uint8_t> buf_compute_layer;
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> probs;
std::vector<float> logits; std::vector<float> logits;
std::vector<whisper_segment> result_all; std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past; std::vector<whisper_token> prompt_past;
std::vector<std::pair<double, whisper_vocab::id>> probs_id; std::vector<std::pair<double, whisper_vocab::id>> logits_id;
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg;
@ -575,6 +574,10 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
fprintf(stderr, "%s: loading model\n", __func__); fprintf(stderr, "%s: loading model\n", __func__);
const int64_t t_start_us = ggml_time_us();
wctx.t_start_us = t_start_us;
auto & model = wctx.model; auto & model = wctx.model;
auto & vocab = wctx.vocab; auto & vocab = wctx.vocab;
@ -770,9 +773,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
} }
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs_id.reserve(n_vocab); wctx.logits_id.reserve(n_vocab);
// TAGS: WHISPER_DECODER_INIT // TAGS: WHISPER_DECODER_INIT
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
@ -1178,6 +1180,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
} }
} }
wctx.t_load_us = ggml_time_us() - t_start_us;
return true; return true;
} }
@ -1192,8 +1196,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// //
static bool whisper_encode( static bool whisper_encode(
whisper_context & wctx, whisper_context & wctx,
const int n_threads, const int mel_offset,
const int mel_offset) { const int n_threads) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & mel_inp = wctx.mel; const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -1585,6 +1591,8 @@ static bool whisper_encode(
ggml_free(ctx0); ggml_free(ctx0);
wctx.t_encode_us += ggml_time_us() - t_start_us;
return true; return true;
} }
@ -1601,10 +1609,12 @@ static bool whisper_encode(
static bool whisper_decode( static bool whisper_decode(
whisper_context & wctx, whisper_context & wctx,
whisper_decoder & decoder, whisper_decoder & decoder,
const int n_threads,
const whisper_token * tokens, const whisper_token * tokens,
const int n_tokens, const int n_tokens,
const int n_past) { const int n_past,
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -1613,7 +1623,6 @@ static bool whisper_decode(
WHISPER_ASSERT(!!kv_self.ctx); WHISPER_ASSERT(!!kv_self.ctx);
auto & logits_out = wctx.logits; auto & logits_out = wctx.logits;
auto & probs_out = wctx.probs;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
@ -1625,6 +1634,8 @@ static bool whisper_decode(
const int N = n_tokens; const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
//fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size(); params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data(); params.mem_buffer = wctx.buf_compute.data();
@ -1933,25 +1944,18 @@ static bool whisper_decode(
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
// logits -> probs
cur = ggml_dup(ctx0, logits);
cur = ggml_soft_max(ctx0, cur); // in-place
// run the computation // run the computation
{ {
struct ggml_cgraph gf = {}; struct ggml_cgraph gf = {};
gf.n_threads = n_threads; gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur); ggml_build_forward_expand(&gf, logits);
ggml_graph_compute (ctx0, &gf); ggml_graph_compute (ctx0, &gf);
} }
logits_out.resize(N*n_vocab); logits_out.resize(N*n_vocab);
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
probs_out.resize(N*n_vocab);
memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
if (N > 1) { if (N > 1) {
//const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
//printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
@ -1960,6 +1964,8 @@ static bool whisper_decode(
ggml_free(ctx0); ggml_free(ctx0);
wctx.t_decode_us += ggml_time_us() - t_start_us;
return true; return true;
} }
@ -2062,6 +2068,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram( static bool log_mel_spectrogram(
whisper_context & wctx,
const float * samples, const float * samples,
const int n_samples, const int n_samples,
const int /*sample_rate*/, const int /*sample_rate*/,
@ -2072,6 +2079,7 @@ static bool log_mel_spectrogram(
const whisper_filters & filters, const whisper_filters & filters,
const bool speed_up, const bool speed_up,
whisper_mel & mel) { whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();
// Hanning window // Hanning window
std::vector<float> hann; std::vector<float> hann;
@ -2180,6 +2188,8 @@ static bool log_mel_spectrogram(
mel.data[i] = (mel.data[i] + 4.0)/4.0; mel.data[i] = (mel.data[i] + 4.0)/4.0;
} }
wctx.t_mel_us += ggml_time_us() - t_start_us;
return true; return true;
} }
@ -2324,10 +2334,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
whisper_context * ctx = new whisper_context; whisper_context * ctx = new whisper_context;
const int64_t t_start_us = ggml_time_us();
ctx->t_start_us = t_start_us;
if (!whisper_model_load(loader, *ctx)) { if (!whisper_model_load(loader, *ctx)) {
loader->close(loader->context); loader->close(loader->context);
fprintf(stderr, "%s: failed to load model\n", __func__); fprintf(stderr, "%s: failed to load model\n", __func__);
@ -2335,8 +2341,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
return nullptr; return nullptr;
} }
ctx->t_load_us = ggml_time_us() - t_start_us;
loader->close(loader->context); loader->close(loader->context);
return ctx; return ctx;
@ -2363,29 +2367,21 @@ void whisper_free(struct whisper_context * ctx) {
} }
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us(); if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1; return -1;
} }
ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0; return 0;
} }
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us(); if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1; return -1;
} }
ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0; return 0;
} }
@ -2409,31 +2405,23 @@ int whisper_set_mel(
} }
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
const int64_t t_start_us = ggml_time_us(); if (!whisper_encode(*ctx, offset, n_threads)) {
if (!whisper_encode(*ctx, n_threads, offset)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return -1; return -1;
} }
ctx->t_encode_us += ggml_time_us() - t_start_us;
return 0; return 0;
} }
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int64_t t_start_us = ggml_time_us();
// TODO: add selected_decoder_id to context // TODO: add selected_decoder_id to context
const int selected_decoder_id = 0; const int selected_decoder_id = 0;
if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) { if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return 1; return 1;
} }
ctx->t_decode_us += ggml_time_us() - t_start_us;
return 0; return 0;
} }
@ -2517,36 +2505,39 @@ int whisper_lang_auto_detect(
return -7; return -7;
} }
auto & probs_id = ctx->probs_id; auto & logits_id = ctx->logits_id;
probs_id.clear(); logits_id.clear();
for (const auto & kv : g_lang) { for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first); const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
} }
// sort descending // sort descending
{ {
using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type; using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first; return a.first > b.first;
}); });
} }
// softmax // softmax
{ {
float sum = 0; const auto max = logits_id[0].first;
for (const auto & kv : probs_id) {
sum += exp(kv.first); double sum = 0.0f;
for (auto & kv : logits_id) {
kv.first = exp(kv.first - max);
sum += kv.first;
} }
for (auto & kv : probs_id) { for (auto & kv : logits_id) {
kv.first = exp(kv.first) / sum; kv.first /= sum;
} }
} }
{ {
for (const auto & prob : probs_id) { for (const auto & prob : logits_id) {
if (lang_probs) { if (lang_probs) {
lang_probs[prob.second] = prob.first; lang_probs[prob.second] = prob.first;
} }
@ -2555,7 +2546,7 @@ int whisper_lang_auto_detect(
} }
} }
return probs_id[0].second; return logits_id[0].second;
} }
int whisper_n_len(struct whisper_context * ctx) { int whisper_n_len(struct whisper_context * ctx) {
@ -2582,10 +2573,6 @@ float * whisper_get_logits(struct whisper_context * ctx) {
return ctx->logits.data(); return ctx->logits.data();
} }
float * whisper_get_probs(struct whisper_context * ctx) {
return ctx->probs.data();
}
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
return ctx->vocab.id_to_token.at(token).c_str(); return ctx->vocab.id_to_token.at(token).c_str();
} }
@ -2878,13 +2865,11 @@ static void whisper_process_logits(
auto & logits = decoder.logits; auto & logits = decoder.logits;
auto & logprobs = decoder.logprobs; auto & logprobs = decoder.logprobs;
{ {
probs.resize(n_logits);
memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float));
logits.resize(n_logits); logits.resize(n_logits);
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
// will be populated a bit later // will be populated a bit later
probs.resize(n_logits);
logprobs.resize(n_logits); logprobs.resize(n_logits);
} }
@ -2904,6 +2889,10 @@ static void whisper_process_logits(
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY; logits[vocab.token_not] = -INFINITY;
// suppress sot and solm tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_solm] = -INFINITY;
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{ {
@ -2914,11 +2903,11 @@ static void whisper_process_logits(
if (last_was_timestamp) { if (last_was_timestamp) {
if (penultimate_was_timestamp) { if (penultimate_was_timestamp) {
for (int i = vocab.token_beg; i < n_logits; ++ i) { for (int i = vocab.token_beg; i < n_logits; ++i) {
logits[i] = -INFINITY; logits[i] = -INFINITY;
} }
} else { } else {
for (int i = 0; i < vocab.token_eot; ++ i) { for (int i = 0; i < vocab.token_eot; ++i) {
logits[i] = -INFINITY; logits[i] = -INFINITY;
} }
} }
@ -2931,7 +2920,7 @@ static void whisper_process_logits(
const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
const int tid0 = std::round(params.max_initial_timestamp/precision); const int tid0 = std::round(params.max_initial_timestamp/precision);
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) { for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
logits[i] = -INFINITY; logits[i] = -INFINITY;
} }
} }
@ -2940,12 +2929,19 @@ static void whisper_process_logits(
{ {
const float logit_max = *std::max_element(logits.begin(), logits.end()); const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f; float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++ i) { for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logsumexp += expf(logits[i] - logit_max); logsumexp += expf(logits[i] - logit_max);
} }
}
logsumexp = logf(logsumexp) + logit_max; logsumexp = logf(logsumexp) + logit_max;
for (int i = 0; i < n_logits; ++ i) {
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logprobs[i] = logits[i] - logsumexp; logprobs[i] = logits[i] - logsumexp;
} else {
logprobs[i] = -INFINITY;
}
} }
} }
@ -2957,20 +2953,37 @@ static void whisper_process_logits(
{ {
float logsumexp = 0.0f; float logsumexp = 0.0f;
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
for (int i = vocab.token_beg; i < n_logits; ++ i) { for (int i = vocab.token_beg; i < n_logits; ++i) {
if (logprobs[i] > -INFINITY) {
logsumexp += expf(logprobs[i] - logprob_max); logsumexp += expf(logprobs[i] - logprob_max);
} }
logsumexp = logf(logsumexp) + logprob_max; }
timestamp_logprob = logsumexp; if (logsumexp > 0.0f) {
timestamp_logprob = logf(logsumexp) + logprob_max;
}
} }
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
//fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
if (timestamp_logprob > max_text_token_logprob) { if (timestamp_logprob > max_text_token_logprob) {
for (int i = 0; i < vocab.token_beg; ++ i) { for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY; logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
}
} }
} }
// compute probs
{
for (int i = 0; i < n_logits; ++i) {
if (logits[i] == -INFINITY) {
probs[i] = 0.0f;
} else {
probs[i] = expf(logprobs[i]);
}
} }
} }
@ -2978,8 +2991,10 @@ static void whisper_process_logits(
// print first 100 logits - token string : logit // print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i); const auto token = vocab.id_to_token.at(i);
const auto prob = probs[i];
const auto logit = logits[i]; const auto logit = logits[i];
printf("%s : %f\n", token.c_str(), logit); const auto logprob = logprobs[i];
printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
} }
// "And", "and", " And", " and" // "And", "and", " And", " and"
@ -2994,26 +3009,13 @@ static void whisper_process_logits(
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
#endif
//switch (params.strategy) {
// case WHISPER_SAMPLING_GREEDY:
// {
// // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
// // TODO: implement
// result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
// } break;
// case WHISPER_SAMPLING_BEAM_SEARCH:
// {
// // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
// // TODO: implement
// } break;
//}
//sum_logprobs += logprobs[result.id]; printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
//printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
//return result; printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
} }
// select the most probable token // select the most probable token
@ -3052,13 +3054,6 @@ static whisper_token_data whisper_sample_best(
} }
for (int i = 0; i < n_logits; ++i) { for (int i = 0; i < n_logits; ++i) {
// never sample these:
if (i == vocab.token_sot ||
i == vocab.token_solm ||
i == vocab.token_not) {
continue;
}
if (result.p < probs[i]) { if (result.p < probs[i]) {
result.id = i; result.id = i;
result.p = probs[i]; result.p = probs[i];
@ -3334,18 +3329,23 @@ int whisper_full(
// print the prompt // print the prompt
//printf("\n\n"); //printf("\n\n");
//for (int i = 0; i < prompt.size(); i++) { //for (int i = 0; i < (int) prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
//} //}
//printf("\n\n"); //printf("\n\n");
if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) { if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);
return -7; return -7;
} }
{
const int64_t t_start_sample_us = ggml_time_us();
whisper_process_logits(*ctx, ctx->decoders[0], params); whisper_process_logits(*ctx, ctx->decoders[0], params);
ctx->decoders[0].n_past += prompt.size();
for (int j = 1; j < n_decoders_cur; ++j) { for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -3357,9 +3357,14 @@ int whisper_full(
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
} }
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
} }
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
const int64_t t_start_sample_us = ggml_time_us();
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -3407,7 +3412,8 @@ int whisper_full(
{ {
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
} }
// end of segment // end of segment
@ -3468,6 +3474,8 @@ int whisper_full(
} }
} }
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -3480,14 +3488,20 @@ int whisper_full(
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) { if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);
return -7; return -7;
} }
{
const int64_t t_start_sample_us = ggml_time_us();
whisper_process_logits(*ctx, decoder, params); whisper_process_logits(*ctx, decoder, params);
++decoder.n_past; ++decoder.n_past;
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
} }
} }
@ -3502,12 +3516,15 @@ int whisper_full(
continue; continue;
} }
whisper_sequence_score(params, ctx->decoders[j].sequence); decoder.sequence.tokens.resize(decoder.sequence.result_len);
whisper_sequence_score(params, decoder.sequence);
if (best_score < decoder.sequence.score) { if (best_score < decoder.sequence.score) {
best_score = decoder.sequence.score; best_score = decoder.sequence.score;
best_decoder_id = j; best_decoder_id = j;
} }
fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
} }
} }
@ -3671,9 +3688,8 @@ int whisper_full_parallel(
ctx_p = *ctx; ctx_p = *ctx;
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab); ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
if (!kv_cache_reinit(ctx_p.kv_cross)) { if (!kv_cache_reinit(ctx_p.kv_cross)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);

@ -191,12 +191,6 @@ extern "C" {
// Cols: n_vocab // Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode()
// The probabilities for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);

Loading…
Cancel
Save