From a1c1583cc7cd8b75222857afc936f0638c5683d6 Mon Sep 17 00:00:00 2001 From: kamranjon Date: Sun, 5 Feb 2023 04:46:26 -0800 Subject: [PATCH] whisper : add whisper_full_lang_id() for getting the context lang (#461) --- whisper.cpp | 9 ++++++++- whisper.h | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index 1a4a207..aedd343 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -592,6 +592,8 @@ struct whisper_context { mutable std::mt19937 rng; // used for sampling at t > 0.0 + int lang_id; + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -3478,7 +3480,7 @@ int whisper_full( fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - + ctx->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); @@ -3575,6 +3577,7 @@ int whisper_full( std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); + ctx->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -4295,6 +4298,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) { return ctx->result_all.size(); } +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->lang_id; +} + int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { return ctx->result_all[i_segment].t0; } diff --git a/whisper.h b/whisper.h index 3a42668..72331e6 100644 --- a/whisper.h +++ b/whisper.h @@ -330,6 +330,9 @@ extern "C" { // A segment can be a few words, a sentence, or even a paragraph. WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + // Language id associated with the current context + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + // Get the start and end time of the specified segment. WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);