diff --git a/whisper.cpp b/whisper.cpp index 38c994d..dad1139 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -15,6 +15,7 @@ #include #include #include +#include <random> #define WHISPER_ASSERT(x) \ do { \ @@ -479,6 +480,8 @@ struct whisper_context { std::vector> logits_id; + mutable std::mt19937 rng; // used for sampling at t > 0.0 + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -1180,6 +1183,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + wctx.rng = std::mt19937(0); + wctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -2848,9 +2853,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { // - applyies logit filters // - computes logprobs static void whisper_process_logits( - struct whisper_context & ctx, - struct whisper_decoder & decoder, - struct whisper_full_params params) { + const struct whisper_context & ctx, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { const auto & vocab = ctx.vocab; const auto & tokens_cur = decoder.sequence.tokens; @@ -2868,6 +2874,12 @@ static void whisper_process_logits( logits.resize(n_logits); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + // will be populated a bit later probs.resize(n_logits); logprobs.resize(n_logits); @@ -3018,10 +3030,10 @@ static void whisper_process_logits( #endif } -// select the most probable token -static whisper_token_data whisper_sample_best( - whisper_context & ctx, - whisper_decoder & decoder) { +static whisper_token_data whisper_sample_token( + const whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -3053,12 +3065,20 @@ static whisper_token_data whisper_sample_best( result.ptsum = sum_ts;
} - for (int i = 0; i < n_logits; ++i) { - if (result.p < probs[i]) { - result.id = i; - result.p = probs[i]; - result.plog = logprobs[i]; + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(ctx.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; } return result; @@ -3167,22 +3187,24 @@ int whisper_full( } break; }; - for (int i = 1; i < n_decoders; i++) { - // TAGS: WHISPER_DECODER_INIT - if (ctx->decoders[i].kv_self.ctx == nullptr) { - ctx->decoders[i].kv_self = ctx->decoders[0].kv_self; - if (!kv_cache_reinit(ctx->decoders[i].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i); + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = ctx->decoders[j]; + + if (decoder.kv_self.ctx == nullptr) { + decoder.kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(decoder.kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; } - fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i); + fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); - ctx->decoders[i].probs.resize (ctx->vocab.n_vocab); - ctx->decoders[i].logits.resize (ctx->vocab.n_vocab); - ctx->decoders[i].logprobs.resize(ctx->vocab.n_vocab); + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); } } @@ -3268,6 +3290,7 @@ int whisper_full( const float t_cur = temperatures[it]; int n_decoders_cur = 1; + switch 
(params.strategy) { case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { @@ -3330,7 +3353,7 @@ int whisper_full( // print the prompt //printf("\n\n"); //for (int i = 0; i < (int) prompt.size(); i++) { - // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); //} //printf("\n\n"); @@ -3342,7 +3365,7 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, ctx->decoders[0], params); + whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); ctx->decoders[0].n_past += prompt.size(); @@ -3376,8 +3399,9 @@ int whisper_full( case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); } } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: @@ -3402,7 +3426,7 @@ int whisper_full( // do not allow to go back in time if (has_ts && seek_delta > seek_delta_new && result_len < i) { failed = true; // TODO: maybe this is not a failure ? - break; + continue; } seek_delta = seek_delta_new; @@ -3410,11 +3434,11 @@ int whisper_full( has_ts = true; } - { - const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", - __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str()); - } + //{ + // const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token.at(token.tid) : "[?]"; + // printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + // __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + //} // end of segment if (token.id == whisper_token_eot(ctx) || // end of text token @@ -3426,7 +3450,7 @@ int whisper_full( result_len = i + 1; } else { failed = true; - break; + continue; } } @@ -3436,14 +3460,14 @@ int whisper_full( } completed = true; - break; + continue; } // TESTS: if no tensors are loaded, it means we are running tests if (ctx->model.n_loaded == 0) { seek_delta = 100*WHISPER_CHUNK_SIZE; completed = true; - break; + continue; } } @@ -3451,7 +3475,7 @@ int whisper_full( // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { failed = true; - break; + continue; } } @@ -3496,7 +3520,7 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, decoder, params); + whisper_process_logits(*ctx, params, decoder, t_cur); ++decoder.n_past; @@ -3524,7 +3548,7 @@ int whisper_full( best_decoder_id = j; } - fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs); + //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs); } } @@ -3541,6 +3565,11 @@ int whisper_full( } if (success) { + //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, 
token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + break; }