diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 0ee3306..3ea563a 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -510,351 +510,390 @@ std::vector read_allowed_commands(const std::string & fname) { return allowed_commands; } -int main(int argc, char ** argv) { - whisper_params params; +// command-list mode +// guide the transcription to match the most likely command from a provided list +int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: guided mode\n", __func__); + + std::vector allowed_commands = read_allowed_commands(params.commands); + + if (allowed_commands.empty()) { + fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); + return 2; + } + + int max_len = 0; + + std::vector> allowed_tokens; + + for (const auto & cmd : allowed_commands) { + whisper_token tokens[1024]; + allowed_tokens.emplace_back(); + + for (int l = 0; l < (int) cmd.size(); ++l) { + // NOTE: very important to add the whitespace ! + // the reason is that the first decoded token starts with a whitespace too! + std::string ss = std::string(" ") + cmd.substr(0, l + 1); + + const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str()); + return 3; + } + + if (n == 1) { + allowed_tokens.back().push_back(tokens[0]); + } + } + + max_len = std::max(max_len, (int) cmd.size()); + } + + fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__); + fprintf(stderr, "\n"); + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str()); + for (const auto & token : allowed_tokens[i]) { + fprintf(stderr, " %5d", token); + } + fprintf(stderr, " ]\n"); + } + + std::string k_prompt = "select one from the available words: "; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + if (i > 0) { + k_prompt += ", "; + } + k_prompt += allowed_commands[i]; + } + k_prompt += ". selected word: "; + + // tokenize prompt + std::vector k_tokens; + { + k_tokens.resize(1024); + const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str()); + return 4; + } + k_tokens.resize(n); + } + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str()); + fprintf(stderr, "%s: tokens: [", __func__); + for (const auto & token : k_tokens) { + fprintf(stderr, " %d", token); + } + fprintf(stderr, " ]\n"); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: listening for a command ...\n", __func__); + fprintf(stderr, "\n"); + + bool is_running = true; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + // main loop + while (is_running) { + // handle Ctrl + C + { + SDL_Event event; + while (SDL_PollEvent(&event)) { + switch (event.type) { + case SDL_QUIT: + { + is_running = false; + } break; + default: + break; + } + } - if (whisper_params_parse(argc, argv, params) == false) { - return 1; - } + if (!is_running) { + return 0; + } + } - if (whisper_lang_id(params.language.c_str()) == -1) { - fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); - whisper_print_usage(argc, argv, params); - exit(0); - } + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); - // whisper init + audio.get(2000, pcmf32_cur); - struct whisper_context * ctx = whisper_init(params.model.c_str()); + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - // print some info about the processing - { - fprintf(stderr, "\n"); - if (!whisper_is_multilingual(ctx)) { - if (params.language != "en" || params.translate) { - params.language = "en"; - params.translate = false; - fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); - } - } - fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n", - __func__, - params.n_threads, - params.language.c_str(), - params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); + const auto t_start = std::chrono::high_resolution_clock::now(); - fprintf(stderr, "\n"); - } + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - // init audio + wparams.print_progress = false; + wparams.print_special = params.print_special; + wparams.print_realtime = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.translate = params.translate; + wparams.no_context = true; + wparams.single_segment = true; + wparams.max_tokens = 1; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; - audio_async audio(30*1000); - if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) { - fprintf(stderr, "%s: audio.init() failed!\n", __func__); - return 1; - } + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; - audio.resume(); + wparams.prompt_tokens = k_tokens.data(); + wparams.prompt_n_tokens = k_tokens.size(); - // wait for 1 second to avoid any buffered noise - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - audio.clear(); + // run the transformer and a single decoding pass + if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { + fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); + break; + } - int max_len = 0; + const auto * probs = whisper_get_probs(ctx); + std::vector> probs_id; - bool is_running = true; - bool have_prompt = false; - bool ask_prompt = true; + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.emplace_back(probs[allowed_tokens[i][0]], i); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; + } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } + + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // print the commands and the respective probabilities + { + fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int token : allowed_tokens[cmd.second]) { + fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + } + fprintf(stdout, "\n"); + } + } - float prob0 = 0.0f; - float prob = 0.0f; + // best command + { + const auto t_end = std::chrono::high_resolution_clock::now(); - std::vector pcmf32_cur; - std::vector pcmf32_prompt; + const float prob = probs_id[0].first; + const int index = probs_id[0].second; - std::vector allowed_commands; - std::vector> allowed_tokens; + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, + (int) std::chrono::duration_cast(t_end - t_start).count()); + fprintf(stdout, "\n"); + } - std::string k_prompt; - std::vector k_tokens; + audio.clear(); + } + } - if (!params.commands.empty()) { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: guided mode\n", __func__); + return 0; +} - allowed_commands = read_allowed_commands(params.commands); +// general-purpose mode +// freely transcribe the voice into text +int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { + bool is_running = true; + bool have_prompt = false; + bool ask_prompt = true; + + float prob0 = 0.0f; + float prob = 0.0f; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + const std::string k_prompt = "Ok Whisper, start listening for commands."; + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: general-purpose mode\n", __func__); + + // main loop + while (is_running) { + // handle Ctrl + C + { + SDL_Event event; + while (SDL_PollEvent(&event)) { + switch (event.type) { + case SDL_QUIT: + { + is_running = false; + } break; + default: + break; + } + } - if (allowed_commands.empty()) { - fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); - return 2; - } + if (!is_running) { + return 0; + } + } - for (const auto & cmd : allowed_commands) { - whisper_token tokens[1024]; - allowed_tokens.emplace_back(); + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); - for (int l = 0; l < (int) cmd.size(); ++l) { - // NOTE: very important to add the whitespace ! - // the reason is that the first decoded token starts with a whitespace too! - std::string ss = std::string(" ") + cmd.substr(0, l + 1); + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); - const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024); - if (n < 0) { - fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str()); - return 3; - } + ask_prompt = false; + } - if (n == 1) { - allowed_tokens.back().push_back(tokens[0]); - } - } + { + audio.get(2000, pcmf32_cur); - max_len = std::max(max_len, (int) cmd.size()); - } + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__); - fprintf(stderr, "\n"); - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str()); - for (const auto & token : allowed_tokens[i]) { - fprintf(stderr, " %5d", token); - } - fprintf(stderr, " ]\n"); - } + int64_t t_ms = 0; - k_prompt = "select one from the available words: "; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - if (i > 0) { - k_prompt += ", "; - } - k_prompt += allowed_commands[i]; - } - k_prompt += ". selected word: "; - - // tokenize prompt - { - k_tokens.resize(1024); - const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); - if (n < 0) { - fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str()); - return 4; - } - k_tokens.resize(n); - } + if (!have_prompt) { + // wait for activation phrase + audio.get(params.prompt_ms, pcmf32_cur); - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str()); - fprintf(stderr, "%s: tokens: [", __func__); - for (const auto & token : k_tokens) { - fprintf(stderr, " %d", token); - } - fprintf(stderr, " ]\n"); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); - fprintf(stderr, "\n"); - fprintf(stderr, "%s: listening for a command ...\n", __func__); - fprintf(stderr, "\n"); + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); - } else { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: general-purpose mode\n", __func__); - - k_prompt = "Ok Whisper, start listening for commands."; - } - - // main loop - while (is_running) { - // handle Ctrl + C - { - SDL_Event event; - while (SDL_PollEvent(&event)) { - switch (event.type) { - case SDL_QUIT: - { - is_running = false; - } break; - default: - break; - } - } + const float sim = similarity(txt, k_prompt); - if (!is_running) { - break; - } - } + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); - // delay - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } + } else { + // we have heard the activation phrase, now detect the commands + audio.get(params.command_ms, pcmf32_cur); - if (allowed_commands.empty()) { - // general-purpose mode - // freely transcribe the voice into text + // prepend the prompt audio + pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - if (ask_prompt) { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); - fprintf(stdout, "\n"); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); - ask_prompt = false; - } + prob = 100.0f*(prob - prob0); - { - int64_t t_ms = 0; + //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); - audio.get(2000, pcmf32_cur); + // find the prompt in the text + float best_sim = 0.0f; + size_t best_len = 0; + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + const auto prompt = txt.substr(0, n); - if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + const float sim = similarity(prompt, k_prompt); - if (!have_prompt) { - // wait for activation phrase - audio.get(params.prompt_ms, pcmf32_cur); + //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + if (sim > best_sim) { + best_sim = sim; + best_len = n; + } + } - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + const std::string command = ::trim(txt.substr(best_len)); - const float sim = similarity(txt, k_prompt); + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "\n"); + } - if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { - fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); - ask_prompt = true; - } else { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); - fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); - fprintf(stdout, "\n"); + audio.clear(); + } + } + } - // save the audio for the prompt - pcmf32_prompt = pcmf32_cur; - have_prompt = true; - } - } else { - // we have heard the activation phrase, now detect the commands - audio.get(params.command_ms, pcmf32_cur); + return 0; +} - // prepend the prompt audio - pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); +int main(int argc, char ** argv) { + whisper_params params; - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + if (whisper_params_parse(argc, argv, params) == false) { + return 1; + } - prob = 100.0f*(prob - prob0); + if (whisper_lang_id(params.language.c_str()) == -1) { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } - //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + // whisper init - // find the prompt in the text - float best_sim = 0.0f; - size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { - const auto prompt = txt.substr(0, n); + struct whisper_context * ctx = whisper_init(params.model.c_str()); - const float sim = similarity(prompt, k_prompt); + // print some info about the processing + { + fprintf(stderr, "\n"); + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n", + __func__, + params.n_threads, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1); - //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + fprintf(stderr, "\n"); + } - if (sim > best_sim) { - best_sim = sim; - best_len = n; - } - } + // init audio - const std::string command = ::trim(txt.substr(best_len)); + audio_async audio(30*1000); + if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) { + fprintf(stderr, "%s: audio.init() failed!\n", __func__); + return 1; + } - fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); - fprintf(stdout, "\n"); - } + audio.resume(); - audio.clear(); - } - } - } else { - // command-list mode - // guide the transcription to match the most likely command from a provided list - - audio.get(2000, pcmf32_cur); - - if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - - const auto t_start = std::chrono::high_resolution_clock::now(); - - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - - wparams.print_progress = false; - wparams.print_special = params.print_special; - wparams.print_realtime = false; - wparams.print_timestamps = !params.no_timestamps; - wparams.translate = params.translate; - wparams.no_context = true; - wparams.single_segment = true; - wparams.max_tokens = 1; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - - wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; - - wparams.prompt_tokens = k_tokens.data(); - wparams.prompt_n_tokens = k_tokens.size(); - - // run the transformer and a single decoding pass - if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { - fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); - break; - } - - const auto * probs = whisper_get_probs(ctx); - std::vector> probs_id; - - double psum = 0.0; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - probs_id.emplace_back(probs[allowed_tokens[i][0]], i); - for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { - probs_id.back().first += probs[allowed_tokens[i][j]]; - } - probs_id.back().first /= allowed_tokens[i].size(); - psum += probs_id.back().first; - } - - // normalize - for (auto & p : probs_id) { - p.first /= psum; - } - - // sort descending - { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); - } - - // print the commands and the respective probabilities - { - fprintf(stdout, "\n"); - for (const auto & cmd : probs_id) { - fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); - for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) { - fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]); - } - fprintf(stdout, "\n"); - } - } - - // best command - { - const auto t_end = std::chrono::high_resolution_clock::now(); - - fprintf(stdout, "\n"); - fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, - "\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first, - (int) std::chrono::duration_cast(t_end - t_start).count()); - fprintf(stdout, "\n"); - } - - audio.clear(); - } - } + // wait for 1 second to avoid any buffered noise + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + audio.clear(); + + int ret_val = 0; + + if (!params.commands.empty()) { + ret_val = process_command_list(ctx, audio, params); + } else { + ret_val = process_general_transcription(ctx, audio, params); } audio.pause(); @@ -862,5 +901,5 @@ int main(int argc, char ** argv) { whisper_print_timings(ctx); whisper_free(ctx); - return 0; + return ret_val; }