whisper : add whisper_get_logits()

pull/291/head
Georgi Gerganov 3 years ago
parent 2d8d3724b8
commit 21559537e2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -620,6 +620,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.work_logits.reserve(vocab.n_vocab);
wctx.work_logprobs.reserve(vocab.n_vocab);
vocab.probs_id.reserve(n_vocab); vocab.probs_id.reserve(n_vocab);
} }
@ -1004,11 +1007,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
} }
const size_t memory_size = const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
} }
// load weights // load weights
@ -2580,6 +2583,10 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0; return ctx->vocab.is_multilingual() ? 1 : 0;
} }
// Accessor for the raw logits buffer stored in the context.
// Returns the pointer produced by ctx->logits.data(); the buffer stays owned by ctx.
float * whisper_get_logits(struct whisper_context * ctx) {
    float * const logits_begin = ctx->logits.data();
    return logits_begin;
}
float * whisper_get_probs(struct whisper_context * ctx) { float * whisper_get_probs(struct whisper_context * ctx) {
return ctx->probs.data(); return ctx->probs.data();
} }
@ -2842,6 +2849,7 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
static struct whisper_token_data whisper_sample_next_token( static struct whisper_token_data whisper_sample_next_token(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_full_params params, struct whisper_full_params params,
double & sum_logprobs,
const std::vector<whisper_token> & prompt, const std::vector<whisper_token> & prompt,
const std::vector<whisper_token_data> & tokens_cur) { const std::vector<whisper_token_data> & tokens_cur) {
struct whisper_token_data result = {}; struct whisper_token_data result = {};
@ -2849,7 +2857,7 @@ static struct whisper_token_data whisper_sample_next_token(
const auto & vocab = ctx->vocab; const auto & vocab = ctx->vocab;
const bool is_initial = tokens_cur.size() == 0; const bool is_initial = tokens_cur.size() == 0;
const int n_logits = vocab.id_to_token.size(); const int n_logits = vocab.id_to_token.size();
WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab); WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
@ -2949,6 +2957,7 @@ static struct whisper_token_data whisper_sample_next_token(
} }
} }
#if 0
// print first 100 logits - token string : logit // print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i); const auto token = vocab.id_to_token.at(i);
@ -2968,6 +2977,7 @@ static struct whisper_token_data whisper_sample_next_token(
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
#endif
switch (params.strategy) { switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY: case WHISPER_SAMPLING_GREEDY:
@ -2983,6 +2993,9 @@ static struct whisper_token_data whisper_sample_next_token(
} break; } break;
} }
sum_logprobs += logprobs[result.id];
printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
return result; return result;
} }
@ -3151,6 +3164,8 @@ int whisper_full(
bool failed = false; // has the current segment failed to decode? bool failed = false; // has the current segment failed to decode?
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
double sum_logprobs = 0.0;
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);
@ -3163,7 +3178,7 @@ int whisper_full(
// sample the next token based on the selected decoding strategy + parameters // sample the next token based on the selected decoding strategy + parameters
// also, update the sliding window position based on the sampled timestamp tokens // also, update the sliding window position based on the sampled timestamp tokens
{ {
const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur); const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
// timestamp token - update sliding window // timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) { if (token.id > whisper_token_beg(ctx)) {

@ -145,7 +145,7 @@ extern "C" {
// Token sampling methods. // Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode(). // These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function. // You can also implement your own sampling method using the whisper_get_logits() or whisper_get_probs() functions.
// whisper_sample_best() returns the token with the highest probability // whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token // whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
@ -192,7 +192,16 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token // Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode()
// The probabilities for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context

Loading…
Cancel
Save