From fba10a4c68f0533a339174ef81c6a18ea228d331 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 17 Dec 2022 17:58:08 +0200
Subject: [PATCH] whisper : language auto-detect (#59)

---
 examples/main/main.cpp |   4 +-
 whisper.cpp            | 125 +++++++++++++++++++++++++++++++++++++++--
 whisper.h              |  23 ++++++++
 3 files changed, 145 insertions(+), 7 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 042f892..9d889eb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -154,7 +154,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
     fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
     fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
-    fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
+    fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
     fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
     fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
@@ -453,7 +453,7 @@ int main(int argc, char ** argv) {
         return 2;
     }
 
-    if (whisper_lang_id(params.language.c_str()) == -1) {
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
         fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
         whisper_print_usage(argc, argv, params);
         exit(0);
diff --git a/whisper.cpp b/whisper.cpp
index 0aca60c..95bbcdd 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1105,7 +1105,7 @@ static bool whisper_encode(
     struct ggml_init_params params;
 
     params.mem_size = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_buffer = wctx.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     return res.size();
 }
 
+int whisper_lang_max_id() {
+    auto max_id = 0;
+    for (const auto & kv : g_lang) {
+        max_id = std::max(max_id, kv.second.first);
+    }
+
+    return max_id;
+}
+
 int whisper_lang_id(const char * lang) {
     if (!g_lang.count(lang)) {
+        for (const auto & kv : g_lang) {
+            if (kv.second.second == lang) {
+                return kv.second.first;
+            }
+        }
+
         fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
@@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
     return g_lang.at(lang).first;
 }
 
+const char * whisper_lang_str(int id) {
+    for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.first.c_str();
+        }
+    }
+
+    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    return NULL;
+}
+
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
+    const int seek = offset_ms/10;
+
+    if (seek < 0) {
+        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        return -1;
+    }
+
+    if (seek >= ctx->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+        return -2;
+    }
+
+    // run the encoder
+    if (whisper_encode(ctx, seek, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    std::vector<std::pair<float, int>> probs_id;
+    for (const auto kv : g_lang) {
+        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+        probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
+    }
+
+    // sort descending
+    {
+        using pair_type = decltype(probs_id)::value_type;
+        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
+
+    // softmax
+    {
+        float sum = 0;
+        for (const auto & kv : probs_id) {
+            sum += exp(kv.first);
+        }
+
+        for (auto & kv : probs_id) {
+            kv.first = exp(kv.first) / sum;
+        }
+    }
+
+    {
+        for (int i = 0; i < probs_id.size(); i++) {
+            if (lang_probs) {
+                lang_probs[probs_id[i].second] = probs_id[i].first;
+            }
+
+            //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+        }
+    }
+
+    return probs_id[0].second;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
     return ctx->mel.n_len;
 }
@@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
     return ctx->vocab.token_beg;
 }
 
+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+    return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
 whisper_token whisper_token_translate(void) {
     return whisper_vocab::token_translate;
 }
@@ -2661,10 +2760,25 @@ int whisper_full(
     } else {
         if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
+            return -2;
         }
     }
 
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+    }
+
     if (params.token_timestamps) {
         ctx->t_beg = 0;
         ctx->t_last = 0;
@@ -2703,7 +2817,8 @@ int whisper_full(
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
-        prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
+        const int lang_id = whisper_lang_id(params.language);
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
         } else {
@@ -2752,7 +2867,7 @@ int whisper_full(
         // encode audio features starting at offset seek
         if (whisper_encode(ctx, seek, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
-            return 7;
+            return -4;
         }
 
         int n_past = 0;
@@ -2790,7 +2905,7 @@ int whisper_full(
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
-                return 8;
+                return -5;
             }
 
             n_past += prompt.size();
diff --git a/whisper.h b/whisper.h
index a28a3b7..e2657c1 100644
--- a/whisper.h
+++ b/whisper.h
@@ -150,9 +150,30 @@ extern "C" {
             whisper_token * tokens,
             int n_max_tokens);
 
+    // Largest language id (i.e. number of available languages - 1)
+    WHISPER_API int whisper_lang_max_id();
+
     // Return the id of the specified language, returns -1 if not found
+    // Examples:
+    //   "de" -> 2
+    //   "german" -> 2
     WHISPER_API int whisper_lang_id(const char * lang);
 
+    // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str(int id);
+
+    // Use mel data at offset_ms to try to auto-detect the spoken language
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+    // Returns the top language id or negative on failure
+    // If not null, fills the lang_probs array with the probabilities of all languages
+    // The array must be whisper_lang_max_id() + 1 in size
+    // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+    WHISPER_API int whisper_lang_auto_detect(
+            struct whisper_context * ctx,
+            int offset_ms,
+            int n_threads,
+            float * lang_probs);
+
     WHISPER_API int whisper_n_len      (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab    (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
@@ -171,6 +192,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
 
     // Task tokens
     WHISPER_API whisper_token whisper_token_translate (void);
@@ -236,6 +258,7 @@ extern "C" {
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
 
+        // for auto-detection, set to nullptr, "" or "auto"
        const char * language;
 
         struct {
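
Below is a minimal sketch of how the standalone detection API introduced above (whisper_lang_max_id(), whisper_lang_auto_detect(), whisper_lang_str()) might be driven directly, outside of whisper_full(). It is illustrative only and not part of the patch: it assumes the whisper_init()/whisper_free()/whisper_pcm_to_mel() entry points as they exist in whisper.h at this commit, and it feeds a silent dummy buffer where real 16 kHz mono float PCM would go (see examples/main for an actual WAV loader).

// lang_detect_sketch.cpp -- illustrative only, not part of the patch
// Hypothetical build line (paths/flags are assumptions):
//   g++ -O3 -std=c++11 lang_detect_sketch.cpp whisper.cpp ggml.c -pthread -o lang-detect
#include "whisper.h"

#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <ggml-model.bin>\n", argv[0]);
        return 1;
    }

    struct whisper_context * ctx = whisper_init(argv[1]);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model '%s'\n", argv[1]);
        return 1;
    }

    // placeholder: 10 s of silence at 16 kHz -- replace with real mono
    // float PCM in [-1, 1] for a meaningful result
    std::vector<float> pcm(10*WHISPER_SAMPLE_RATE, 0.0f);

    // the detector reads the mel spectrogram, so compute it first
    if (whisper_pcm_to_mel(ctx, pcm.data(), (int) pcm.size(), 4) != 0) {
        fprintf(stderr, "failed to compute mel spectrogram\n");
        whisper_free(ctx);
        return 1;
    }

    // one probability slot per language id
    std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

    // offset_ms = 0: detect from the start of the audio
    const int lang_id = whisper_lang_auto_detect(ctx, 0, 4, probs.data());
    if (lang_id < 0) {
        fprintf(stderr, "language auto-detection failed (%d)\n", lang_id);
        whisper_free(ctx);
        return 1;
    }

    printf("detected language: %s (p = %f)\n", whisper_lang_str(lang_id), probs[lang_id]);

    whisper_free(ctx);
    return 0;
}

The same detection also runs implicitly inside whisper_full() when whisper_full_params.language is set to nullptr, "" or "auto", as added in the whisper.cpp hunk at line 2760 above.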