From 9c4a1522f6243a2a6f30a285dc39f1ae0c1854e5 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 7 Jan 2023 20:41:11 +0100 Subject: [PATCH] command : always-prompt mode (#383) --- examples/command/command.cpp | 118 ++++++++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 3ea563a..524ad67 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -11,6 +11,8 @@ #include #include +#include +#include #include #include #include @@ -25,7 +27,7 @@ struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t prompt_ms = 5000; - int32_t command_ms = 4000; + int32_t command_ms = 8000; int32_t capture_id = -1; int32_t max_tokens = 32; int32_t audio_ctx = 0; @@ -43,6 +45,7 @@ struct whisper_params { std::string model = "models/ggml-base.en.bin"; std::string fname_out; std::string commands; + std::string prompt; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -71,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } + else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -103,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); + 
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
     fprintf(stderr, "\n");
 }
 
@@ -837,6 +842,115 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
     return 0;
 }
 
+
+// always prompt mode
+// Listens for speech, transcribes it and, when the leading words resemble the activation
+// prompt (params.prompt), prints the words that follow as the recognized command.
+//   ctx    - whisper context handed to transcribe()
+//   audio  - audio source that is polled for samples and cleared after each attempt
+//   params - settings; prompt, command_ms, vad_thold, freq_thold, print_energy are read directly
+// Returns 0 once an SDL_QUIT event is received.
+int always_prompt_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
+    bool is_running = true;
+    bool ask_prompt = true;
+
+    float prob = 0.0f;
+
+    std::vector<float> pcmf32_cur;
+
+    const std::string k_prompt = params.prompt;
+
+    // number of whitespace-separated words in the activation prompt
+    int k_prompt_length = 0;
+    {
+        std::istringstream prompt_stream(k_prompt);
+        std::string prompt_word;
+
+        while (prompt_stream >> prompt_word) {
+            k_prompt_length++;
+        }
+    }
+
+    // main loop
+    while (is_running) {
+        // handle Ctrl + C
+        {
+            SDL_Event event;
+            while (SDL_PollEvent(&event)) {
+                switch (event.type) {
+                    case SDL_QUIT:
+                        {
+                            is_running = false;
+                        } break;
+                    default:
+                        break;
+                }
+            }
+
+            if (!is_running) {
+                return 0;
+            }
+        }
+
+        // delay
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        // announce the activation prompt once, on the first pass through the loop
+        if (ask_prompt) {
+            fprintf(stdout, "\n");
+            fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
+            fprintf(stdout, "\n");
+
+            ask_prompt = false;
+        }
+
+        {
+            audio.get(2000, pcmf32_cur);
+
+            if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
+                fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+
+                int64_t t_ms = 0;
+
+                // detect the commands
+                audio.get(params.command_ms, pcmf32_cur);
+
+                const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+
+                std::istringstream txt_stream(txt);
+                std::string word;
+                std::string prompt;
+                std::string command;
+                int command_length = 0;
+
+                // the first k_prompt_length words form the candidate prompt (each followed by a space)
+                for (int i = 0; i < k_prompt_length && (txt_stream >> word); i++) {
+                    prompt += word + ' ';
+                }
+
+                // everything after them is the command
+                while (txt_stream >> word) {
+                    command += word + ' ';
+                    command_length++;
+                }
+
+                const float sim = similarity(prompt, k_prompt);
+
+                // a prompt without any words can never be matched
+                if ((k_prompt_length > 0) && (sim > 0.7f) && (command_length > 0)) {
+                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+                }
+
+                fprintf(stdout, "\n");
+
+                audio.clear();
+            }
+        }
+    }
+
+    return 0;
+}
+
 int main(int argc, char ** argv) {
     whisper_params params;
 
@@ -892,6 +1006,8 @@ int main(int argc, char ** argv) {
 
     if (!params.commands.empty()) {
         ret_val = process_command_list(ctx, audio, params);
+    } else if (!params.prompt.empty()) {
+        ret_val = always_prompt_transcription(ctx, audio, params);
     } else {
         ret_val = process_general_transcription(ctx, audio, params);
     }