|
|
|
@ -620,6 +620,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
|
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
wctx.work_logits.reserve(vocab.n_vocab);
|
|
|
|
|
wctx.work_logprobs.reserve(vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
vocab.probs_id.reserve(n_vocab);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1004,11 +1007,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const size_t memory_size =
|
|
|
|
|
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
|
|
|
|
|
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
|
|
|
|
|
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
|
|
|
|
|
const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
|
|
|
|
fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// load weights
|
|
|
|
@ -2580,6 +2583,10 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
|
|
|
|
|
return ctx->vocab.is_multilingual() ? 1 : 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * whisper_get_logits(struct whisper_context * ctx) {
|
|
|
|
|
return ctx->logits.data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * whisper_get_probs(struct whisper_context * ctx) {
|
|
|
|
|
return ctx->probs.data();
|
|
|
|
|
}
|
|
|
|
@ -2842,6 +2849,7 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
|
|
|
|
|
static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
struct whisper_full_params params,
|
|
|
|
|
double & sum_logprobs,
|
|
|
|
|
const std::vector<whisper_token> & prompt,
|
|
|
|
|
const std::vector<whisper_token_data> & tokens_cur) {
|
|
|
|
|
struct whisper_token_data result = {};
|
|
|
|
@ -2849,7 +2857,7 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
const auto & vocab = ctx->vocab;
|
|
|
|
|
|
|
|
|
|
const bool is_initial = tokens_cur.size() == 0;
|
|
|
|
|
const int n_logits = vocab.id_to_token.size();
|
|
|
|
|
const int n_logits = vocab.id_to_token.size();
|
|
|
|
|
|
|
|
|
|
WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
|
|
|
|
|
|
|
|
|
@ -2949,6 +2957,7 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
// print first 100 logits - token string : logit
|
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
|
|
|
const auto token = vocab.id_to_token.at(i);
|
|
|
|
@ -2968,6 +2977,7 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
|
|
|
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
|
|
|
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
@ -2983,6 +2993,9 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
|
} break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sum_logprobs += logprobs[result.id];
|
|
|
|
|
printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3151,6 +3164,8 @@ int whisper_full(
|
|
|
|
|
bool failed = false; // has the current segment failed to decode?
|
|
|
|
|
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
double sum_logprobs = 0.0;
|
|
|
|
|
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
@ -3163,7 +3178,7 @@ int whisper_full(
|
|
|
|
|
// sample the next token based on the selected decoding strategy + parameters
|
|
|
|
|
// also, update the sliding window position based on the sampled timestamp tokens
|
|
|
|
|
{
|
|
|
|
|
const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);
|
|
|
|
|
const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
|
|
|
|
|
|
|
|
|
|
// timestamp token - update sliding window
|
|
|
|
|
if (token.id > whisper_token_beg(ctx)) {
|
|
|
|
|