whisper : language auto-detect (#59)

pull/288/head
Georgi Gerganov 2 years ago
parent afe2db0fe2
commit fba10a4c68

@ -154,7 +154,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
@ -453,7 +453,7 @@ int main(int argc, char ** argv) {
return 2;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);

@ -1105,7 +1105,7 @@ static bool whisper_encode(
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data();
params.mem_buffer = wctx.buf_compute.data();
struct ggml_context * ctx0 = ggml_init(params);
@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
return res.size();
}
int whisper_lang_max_id() {
auto max_id = 0;
for (const auto & kv : g_lang) {
max_id = std::max(max_id, kv.second.first);
}
return max_id;
}
int whisper_lang_id(const char * lang) {
if (!g_lang.count(lang)) {
for (const auto & kv : g_lang) {
if (kv.second.second == lang) {
return kv.second.first;
}
}
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
return -1;
}
@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
return g_lang.at(lang).first;
}
const char * whisper_lang_str(int id) {
for (const auto & kv : g_lang) {
if (kv.second.first == id) {
return kv.first.c_str();
}
}
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
return NULL;
}
int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs) {
const int seek = offset_ms/10;
if (seek < 0) {
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
return -1;
}
if (seek >= ctx->mel.n_len) {
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
return -2;
}
// run the encoder
if (whisper_encode(ctx, seek, n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return -6;
}
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
std::vector<std::pair<float, int>> probs_id;
for (const auto kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// softmax
{
float sum = 0;
for (const auto & kv : probs_id) {
sum += exp(kv.first);
}
for (auto & kv : probs_id) {
kv.first = exp(kv.first) / sum;
}
}
{
for (int i = 0; i < probs_id.size(); i++) {
if (lang_probs) {
lang_probs[probs_id[i].second] = probs_id[i].first;
}
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
}
}
return probs_id[0].second;
}
int whisper_n_len(struct whisper_context * ctx) {
return ctx->mel.n_len;
}
@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
return ctx->vocab.token_beg;
}
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
return whisper_token_sot(ctx) + 1 + lang_id;
}
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
}
@ -2661,10 +2760,25 @@ int whisper_full(
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
return -2;
}
}
// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
if (lang_id < 0) {
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3;
}
params.language = whisper_lang_str(lang_id);
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
@ -2703,7 +2817,8 @@ int whisper_full(
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
const int lang_id = whisper_lang_id(params.language);
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
} else {
@ -2752,7 +2867,7 @@ int whisper_full(
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return 7;
return -4;
}
int n_past = 0;
@ -2790,7 +2905,7 @@ int whisper_full(
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return 8;
return -5;
}
n_past += prompt.size();

@ -150,9 +150,30 @@ extern "C" {
whisper_token * tokens,
int n_max_tokens);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whispe_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
@ -171,6 +192,7 @@ extern "C" {
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (void);
@ -236,6 +258,7 @@ extern "C" {
const whisper_token * prompt_tokens;
int prompt_n_tokens;
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
struct {

Loading…
Cancel
Save