Store lang_id in context

pull/461/head
kamranjon 2 years ago
parent 2bee2650c6
commit 4d8b454b0b

@ -570,6 +570,8 @@ struct whisper_context {
mutable std::mt19937 rng; // used for sampling at t > 0.0 mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id;
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg;
int64_t t_last; int64_t t_last;
@ -3303,7 +3305,7 @@ int whisper_full(
fprintf(stderr, "%s: failed to auto-detect language\n", __func__); fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3; return -3;
} }
ctx->lang_id = lang_id;
params.language = whisper_lang_str(lang_id); params.language = whisper_lang_str(lang_id);
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
@ -3400,6 +3402,7 @@ int whisper_full(
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) { if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language); const int lang_id = whisper_lang_id(params.language);
ctx->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id)); prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) { if (params.translate) {
prompt_init.push_back(whisper_token_translate()); prompt_init.push_back(whisper_token_translate());
@ -4120,6 +4123,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->result_all.size(); return ctx->result_all.size();
} }
// Returns the language id currently recorded in the context's lang_id field.
// whisper_full() writes that field after language auto-detection and when it
// builds the initial prompt for a multilingual model.
// NOTE(review): the value if whisper_full() has not yet stored a language is
// not visible from here - confirm the field is initialized elsewhere.
int whisper_full_lang_id(struct whisper_context * ctx) {
    const int stored_lang_id = ctx->lang_id;
    return stored_lang_id;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t0; return ctx->result_all[i_segment].t0;
} }

@ -329,6 +329,9 @@ extern "C" {
// A segment can be a few words, a sentence, or even a paragraph. // A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
// Language id stored in the context by whisper_full() (auto-detected or taken from params.language)
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Get the start and end time of the specified segment. // Get the start and end time of the specified segment.
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);

Loading…
Cancel
Save