From 4d8b454b0b03896b8ed0c133180fda2b2e7adc0e Mon Sep 17 00:00:00 2001 From: kamranjon Date: Tue, 31 Jan 2023 15:53:04 -0800 Subject: [PATCH] Store lang_id in context --- whisper.cpp | 9 ++++++++- whisper.h | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index d52143f..b4d4cc9 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -570,6 +570,8 @@ struct whisper_context { mutable std::mt19937 rng; // used for sampling at t > 0.0 + int lang_id; + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -3303,7 +3305,7 @@ int whisper_full( fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - + ctx->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); @@ -3400,6 +3402,7 @@ int whisper_full( std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); + ctx->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -4120,6 +4123,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) { return ctx->result_all.size(); } +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->lang_id; +} + int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { return ctx->result_all[i_segment].t0; } diff --git a/whisper.h b/whisper.h index 51a1888..04ef5b0 100644 --- a/whisper.h +++ b/whisper.h @@ -329,6 +329,9 @@ extern "C" { // A segment can be a few words, a sentence, or even a paragraph. WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + // Language id associated with the current context + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + // Get the start and end time of the specified segment. WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);