diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 459f503..068233f 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -19,6 +19,7 @@ #include #include #include +#include // command-line parameters struct whisper_params { @@ -41,6 +42,7 @@ struct whisper_params { std::string language = "en"; std::string model = "models/ggml-base.en.bin"; std::string fname_out = ""; + std::string commands = ""; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -68,6 +70,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } + else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -83,22 +86,23 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms); - fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms); - fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); - fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); - fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms); + fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms); + fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); + fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); + fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); + fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); fprintf(stderr, "\n"); } @@ -484,6 +488,28 @@ float similarity(const std::string & s0, const std::string & s1) { return 1.0f - (dist / std::max(s0.size(), s1.size())); } +std::vector read_allowed_commands(const std::string & fname) { + std::vector allowed_commands; + + std::ifstream ifs(fname); + if (!ifs.is_open()) { + return allowed_commands; + } + + std::string line; + while (std::getline(ifs, line)) { + line = trim(line); + if (line.empty()) { + continue; + } + + std::transform(line.begin(), line.end(),line.begin(), ::tolower); + allowed_commands.push_back(std::move(line)); + } + + return allowed_commands; +} + int main(int argc, char ** argv) { whisper_params params; @@ -521,7 +547,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); } - // init audio audio_async audio(30*1000); @@ -532,6 +557,8 @@ int main(int argc, char ** argv) { audio.resume(); + int max_len = 0; + bool is_running = true; bool have_prompt = false; bool ask_prompt = true; @@ -542,7 +569,94 @@ int main(int argc, char ** argv) { std::vector pcmf32_cur; std::vector pcmf32_prompt; - const std::string k_prompt = "Ok Whisper, start listening for commands."; + std::vector allowed_commands; + std::vector> allowed_tokens; + + std::string k_prompt = ""; + std::vector k_tokens; + + if (params.commands != "") { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: guided mode\n", __func__); + + allowed_commands = read_allowed_commands(params.commands); + + if (allowed_commands.empty()) { + fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); + return 2; + } + + for (const auto & cmd : allowed_commands) { + whisper_token tokens[1024]; + allowed_tokens.emplace_back(); + + for (int l = 0; l < cmd.size(); ++l) { + // NOTE: very important to add the whitespace ! + // the reason is that the first decoded token starts with a whitespace too! + std::string ss = std::string(" ") + cmd.substr(0, l + 1); + + const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str()); + return 3; + } + + if (n == 1) { + allowed_tokens.back().push_back(tokens[0]); + } + } + + max_len = std::max(max_len, (int) cmd.size()); + } + + fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__); + fprintf(stderr, "\n"); + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str()); + for (const auto & token : allowed_tokens[i]) { + fprintf(stderr, " %d", token); + } + fprintf(stderr, " ]\n"); + } + + k_prompt = "select one from the available words: "; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + if (i > 0) { + k_prompt += ", "; + } + k_prompt += allowed_commands[i]; + } + k_prompt += ". selected word: "; + + // tokenize prompt + { + k_tokens.resize(1024); + const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str()); + return 4; + } + k_tokens.resize(n); + } + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str()); + fprintf(stderr, "%s: tokens: [", __func__); + for (const auto & token : k_tokens) { + fprintf(stderr, " %d", token); + } + fprintf(stderr, " ]\n"); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: listening for a command ...\n", __func__); + fprintf(stderr, "\n"); + + } else { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: general-purpose mode\n", __func__); + + k_prompt = "Ok Whisper, start listening for commands."; + } // main loop while (is_running) { @@ -568,78 +682,172 @@ int main(int argc, char ** argv) { // delay std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (ask_prompt) { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); - fprintf(stdout, "\n"); + if (allowed_commands.empty()) { + // general-purpose mode + // freely transcribe the voice into text - ask_prompt = false; - } + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); - int64_t t_ms = 0; + ask_prompt = false; + } - { - audio.get(2000, pcmf32_cur); + { + int64_t t_ms = 0; - if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + audio.get(2000, pcmf32_cur); + + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + + if (!have_prompt) { + // wait for activation phrase + audio.get(params.prompt_ms, pcmf32_cur); - if (!have_prompt) { - audio.get(params.prompt_ms, pcmf32_cur); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + const float sim = similarity(txt, k_prompt); - const float sim = similarity(txt, k_prompt); + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); - if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { - fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); - ask_prompt = true; + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } } else { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); - fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); - fprintf(stdout, "\n"); + // we have heard the activation phrase, now detect the commands + audio.get(params.command_ms, pcmf32_cur); - // save the audio for the prompt - pcmf32_prompt = pcmf32_cur; - have_prompt = true; + // prepend the prompt audio + pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + + prob = 100.0f*(prob - prob0); + + //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + + // find the prompt in the text + float best_sim = 0.0f; + size_t best_len = 0; + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + const auto prompt = txt.substr(0, n); + + const float sim = similarity(prompt, k_prompt); + + //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + + if (sim > best_sim) { + best_sim = sim; + best_len = n; + } + } + + const std::string command = ::trim(txt.substr(best_len)); + + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "\n"); } - } else { - audio.get(params.command_ms, pcmf32_cur); - // prepend the prompt audio - pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + audio.clear(); + } + } + } else { + // command-list mode + // guide the transcription to match the most likely command from a provided list + + audio.get(2000, pcmf32_cur); + + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + + const auto t_start = std::chrono::high_resolution_clock::now(); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - prob = 100.0f*(prob - prob0); + wparams.print_progress = false; + wparams.print_special = params.print_special; + wparams.print_realtime = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.translate = params.translate; + wparams.no_context = true; + wparams.single_segment = true; + wparams.max_tokens = 1; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; - //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; - // find the prompt in the text - float best_sim = 0.0f; - size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { - const auto prompt = txt.substr(0, n); + wparams.prompt_tokens = k_tokens.data(); + wparams.prompt_n_tokens = k_tokens.size(); - const float sim = similarity(prompt, k_prompt); + // run the transformer and a single decoding pass + if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { + fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); + break; + } - //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + const auto * probs = whisper_get_probs(ctx); + std::vector> probs_id; - if (sim > best_sim) { - best_sim = sim; - best_len = n; - } + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i)); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } - const std::string command = ::trim(txt.substr(best_len)); + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } - fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + // print the commands and the respective probabilities + { fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) { + fprintf(stdout, "%f ", probs[allowed_tokens[cmd.second][i]]); + } + fprintf(stdout, "\n"); + } } + // best command + { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first, + (int) std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t_start).count()); + fprintf(stdout, "\n"); + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + audio.clear(); } } diff --git a/examples/command/commands.txt b/examples/command/commands.txt new file mode 100644 index 0000000..2653de6 --- /dev/null +++ b/examples/command/commands.txt @@ -0,0 +1,9 @@ +enable +disable +cat +dog +apple +red +blue +green +lightblue diff --git a/whisper.cpp b/whisper.cpp index 1f64b35..0aca60c 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2826,13 +2826,13 @@ int whisper_full( //{ // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); //} // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i > params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached ) { if (result_len == 0) { if (seek + seek_delta + 100 >= seek_end) {