From 59c997ca2dfd501b0451ebdf6d798caad044179a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 15 Feb 2023 19:11:12 +0200
Subject: [PATCH] wip ignore

---
 examples/chess/CMakeLists.txt |  10 +
 examples/chess/chess.cpp      | 679 ++++++++++++++++++++++++++++++++++
 2 files changed, 689 insertions(+)
 create mode 100644 examples/chess/CMakeLists.txt
 create mode 100644 examples/chess/chess.cpp

diff --git a/examples/chess/CMakeLists.txt b/examples/chess/CMakeLists.txt
new file mode 100644
index 0000000..cd63c84
--- /dev/null
+++ b/examples/chess/CMakeLists.txt
@@ -0,0 +1,10 @@
+if (WHISPER_SUPPORT_SDL2)
+    # chess
+    set(TARGET chess)
+    add_executable(${TARGET} chess.cpp)
+
+    include(DefaultTargetOptions)
+
+    target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
+    target_link_libraries(${TARGET} PRIVATE common whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
+endif ()
diff --git a/examples/chess/chess.cpp b/examples/chess/chess.cpp
new file mode 100644
index 0000000..c9737b9
--- /dev/null
+++ b/examples/chess/chess.cpp
@@ -0,0 +1,679 @@
+// Input chess moves via voice
+//
+
+#include "common.h"
+#include "whisper.h"
+
+#include <SDL.h>
+#include <SDL_audio.h>
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <exception>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <vector>
+
+// convert a timestamp given in units of 10 ms to a "MM:SS.mmm" string
+//  500 -> 00:05.000
+// 6000 -> 01:00.000
+std::string to_timestamp(int64_t t) {
+    int64_t sec = t/100;
+    int64_t msec = (t - sec*100)*10; // the remainder is in 1/100 sec, scale it to milliseconds
+    int64_t min = sec/60;
+    sec = sec - min*60;
+
+    char buf[32];
+    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
+
+    return std::string(buf);
+}
+
+// command-line parameters
+struct whisper_params {
+    int32_t n_threads  = std::max(1, std::min(4, (int32_t) std::thread::hardware_concurrency())); // hardware_concurrency() may report 0
+    int32_t step_ms    = 3000;  // audio step size in milliseconds (<= 0 makes main use the VAD mode)
+    int32_t length_ms  = 10000; // audio length in milliseconds
+    int32_t keep_ms    = 200;   // audio to keep from previous step in ms
+    int32_t capture_id = -1;    // capture device ID (< 0 - default device)
+    int32_t max_tokens = 32;    // maximum number of tokens per audio chunk
+    int32_t audio_ctx  = 0;     // audio context size (0 - all)
+
+    float vad_thold    = 0.6f;   // voice activity detection threshold
+    float freq_thold   = 100.0f; // high-pass frequency cutoff
+
+    bool translate     = false; // translate from source language to english
+    bool print_special = false; // print special tokens
+    bool no_context    = true;  // do not keep context between audio chunks (cleared by --keep-context)
+    bool no_timestamps = false; // overwritten by main (= !use_vad)
+
+    std::string language  = "en";                      // spoken language
+    std::string model     = "models/ggml-base.en.bin"; // model path
+    std::string fname_inp;                             // input WAV file path (parsed, but not used by main)
+};
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+// parse the command-line arguments into params
+// prints the usage and returns false on an unknown argument or on a missing/malformed option value
+bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+    for (int i = 1; i < argc; i++) {
+        const std::string arg = argv[i];
+
+        // return the value following the current option - throw instead of reading past the end of argv
+        const auto next = [&]() -> std::string {
+            if (i + 1 >= argc) {
+                throw std::invalid_argument("missing value");
+            }
+            return argv[++i];
+        };
+
+        try {
+            if (arg == "-h" || arg == "--help") {
+                whisper_print_usage(argc, argv, params);
+                exit(0);
+            }
+            else if (arg == "-t"   || arg == "--threads")       { params.n_threads     = std::stoi(next()); }
+            else if (                 arg == "--step")          { params.step_ms       = std::stoi(next()); }
+            else if (                 arg == "--length")        { params.length_ms     = std::stoi(next()); }
+            else if (                 arg == "--keep")          { params.keep_ms       = std::stoi(next()); }
+            else if (arg == "-c"   || arg == "--capture")       { params.capture_id    = std::stoi(next()); }
+            else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(next()); }
+            else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(next()); }
+            else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(next()); }
+            else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(next()); }
+            else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
+            else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
+            else if (arg == "-kc"  || arg == "--keep-context")  { params.no_context    = false; }
+            else if (arg == "-l"   || arg == "--language")      { params.language      = next(); }
+            else if (arg == "-m"   || arg == "--model")         { params.model         = next(); }
+            else if (arg == "-f"   || arg == "--file")          { params.fname_inp     = next(); }
+            else {
+                fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+                whisper_print_usage(argc, argv, params);
+                return false;
+            }
+        } catch (const std::exception &) {
+            // thrown by next() for a missing value and by std::stoi/std::stof for a malformed or out-of-range number
+            fprintf(stderr, "error: missing or invalid value for argument: %s\n", arg.c_str());
+            whisper_print_usage(argc, argv, params);
+            return false;
+        }
+    }
+
+    return true;
+}
+
+// print the list of supported options together with the current parameter values to stderr
+void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
+    fprintf(stderr, "\n");
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, " -h, --help [default] show this help message and exit\n");
+    fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
+    fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
+    fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
+    fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
+    fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
+    fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
+    fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
+    fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
+    fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
+    fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
+    fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
+    fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
+    fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
+    fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+    fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", params.fname_inp.c_str());
+    fprintf(stderr, "\n");
+}
+
+//
+// SDL Audio capture
+//
+
+class audio_async {
+public:
+    audio_async(int len_ms);
+    ~audio_async();
+
+    // open the capture device (capture_id < 0 - the default one) and allocate the circular buffer
+    bool init(int capture_id, int sample_rate);
+
+    // start/stop capturing audio via the SDL callback - the last len_ms milliseconds are kept in a circular buffer
+    // clear() discards everything that has been captured so far
+    bool resume();
+    bool pause();
+    bool clear();
+
+    // callback to be called by SDL
+    void callback(uint8_t * stream, int len);
+
+    // get audio data from the circular buffer
+    void get(int ms, std::vector<float> & audio);
+
+private:
+    SDL_AudioDeviceID m_dev_id_in = 0;
+
+    int m_len_ms = 0;
+    int m_sample_rate = 0;
+
+    std::atomic_bool m_running;
+    std::mutex       m_mutex; // guards m_audio_pos, m_audio_len and the contents of m_audio
+
+    std::vector<float> m_audio;         // circular buffer, sized in init()
+    size_t             m_audio_pos = 0; // index at which the next sample is written
+    size_t             m_audio_len = 0; // number of valid samples in the buffer
+};
+
+audio_async::audio_async(int len_ms) {
+    m_len_ms = len_ms;
+
+    m_running = false;
+}
+
+audio_async::~audio_async() {
+    if (m_dev_id_in) {
+        SDL_CloseAudioDevice(m_dev_id_in);
+    }
+}
+
+// initialize SDL audio, list the available capture devices and open the requested one
+// returns false if SDL or the capture device could not be initialized
+bool audio_async::init(int capture_id, int sample_rate) {
+    SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
+
+    if (SDL_Init(SDL_INIT_AUDIO) < 0) {
+        SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
+        return false;
+    }
+
+    SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
+
+    {
+        int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
+        fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
+        for (int i = 0; i < nDevices; i++) {
+            fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
+        }
+    }
+
+    SDL_AudioSpec capture_spec_requested;
+    SDL_AudioSpec capture_spec_obtained;
+
+    SDL_zero(capture_spec_requested);
+    SDL_zero(capture_spec_obtained);
+
+    capture_spec_requested.freq     = sample_rate;
+    capture_spec_requested.format   = AUDIO_F32;
+    capture_spec_requested.channels = 1;
+    capture_spec_requested.samples  = 1024;
+    capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
+        audio_async * audio = (audio_async *) userdata;
+        audio->callback(stream, len);
+    };
+    capture_spec_requested.userdata = this;
+
+    if (capture_id >= 0) {
+        fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
+        m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
+    } else {
+        fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
+        m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
+    }
+
+    if (!m_dev_id_in) {
+        fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
+        m_dev_id_in = 0;
+
+        return false;
+    } else {
+        fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
+        fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
+        fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
+                capture_spec_requested.format);
+        fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
+                capture_spec_requested.channels);
+        fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
+    }
+
+    m_sample_rate = capture_spec_obtained.freq;
+
+    m_audio.resize(std::max<int64_t>(0, ((int64_t) m_sample_rate*m_len_ms)/1000)); // 64-bit math - the product can overflow int
+
+    return true;
+}
+
+// unpause the capture device - returns false if there is no device or the capture is already running
+bool audio_async::resume() {
+    if (!m_dev_id_in) {
+        fprintf(stderr, "%s: no audio device to resume!\n", __func__);
+        return false;
+    }
+
+    if (m_running) {
+        fprintf(stderr, "%s: already running!\n", __func__);
+        return false;
+    }
+
+    SDL_PauseAudioDevice(m_dev_id_in, 0);
+
+    m_running = true;
+
+    return true;
+}
+
+// pause the capture device - returns false if there is no device or the capture is already paused
+bool audio_async::pause() {
+    if (!m_dev_id_in) {
+        fprintf(stderr, "%s: no audio device to pause!\n", __func__);
+        return false;
+    }
+
+    if (!m_running) {
+        fprintf(stderr, "%s: already paused!\n", __func__);
+        return false;
+    }
+
+    SDL_PauseAudioDevice(m_dev_id_in, 1);
+
+    m_running = false;
+
+    return true;
+}
+
+// discard all buffered audio - returns false if there is no device or the capture is not running
+bool audio_async::clear() {
+    if (!m_dev_id_in) {
+        fprintf(stderr, "%s: no audio device to clear!\n", __func__);
+        return false;
+    }
+
+    if (!m_running) {
+        fprintf(stderr, "%s: not running!\n", __func__);
+        return false;
+    }
+
+    {
+        std::lock_guard<std::mutex> lock(m_mutex);
+
+        m_audio_pos = 0;
+        m_audio_len = 0;
+    }
+
+    return true;
+}
+
+// callback to be called by SDL
+// appends the len bytes of float samples in stream to the circular buffer, overwriting the oldest samples
+void audio_async::callback(uint8_t * stream, int len) {
+    if (!m_running) {
+        return;
+    }
+
+    size_t n_samples = len / sizeof(float);
+
+    //fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
+
+    {
+        std::lock_guard<std::mutex> lock(m_mutex);
+
+        if (m_audio.empty()) {
+            // nothing to store into (also avoids the modulo by zero below)
+            return;
+        }
+
+        if (n_samples > m_audio.size()) {
+            // more data than the buffer can hold - keep only the most recent samples
+            stream   += (n_samples - m_audio.size()) * sizeof(float);
+            n_samples = m_audio.size();
+        }
+
+        if (m_audio_pos + n_samples > m_audio.size()) {
+            const size_t n0 = m_audio.size() - m_audio_pos;
+
+            // stream is a byte pointer, so the offset of the second part has to be given in bytes
+            memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
+            memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
+
+            m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
+            m_audio_len = m_audio.size();
+        } else {
+            memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
+
+            m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
+            m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
+        }
+    }
+}
+
+// copy the most recent ms milliseconds of captured audio (ms <= 0 - use len_ms) into result
+// result is left untouched if there is no device or the capture is not running
+void audio_async::get(int ms, std::vector<float> & result) {
+    if (!m_dev_id_in) {
+        fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
+        return;
+    }
+
+    if (!m_running) {
+        fprintf(stderr, "%s: not running!\n", __func__);
+        return;
+    }
+
+    result.clear();
+
+    {
+        std::lock_guard<std::mutex> lock(m_mutex);
+
+        if (ms <= 0) {
+            ms = m_len_ms;
+        }
+
+        size_t n_samples = ((int64_t) m_sample_rate * ms) / 1000;
+        if (n_samples > m_audio_len) {
+            n_samples = m_audio_len;
+        }
+
+        if (n_samples == 0) {
+            return;
+        }
+
+        result.resize(n_samples);
+
+        // index of the oldest requested sample - computed in unsigned arithmetic, wrapping around the buffer
+        const size_t s0 = (m_audio_pos + m_audio.size() - n_samples) % m_audio.size();
+
+        if (s0 + n_samples > m_audio.size()) {
+            const size_t n0 = m_audio.size() - s0;
+
+            memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
+            memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
+        } else {
+            memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
+        }
+    }
+}
+
+///////////////////////////
+
+int main(int argc, char ** argv) {
+    whisper_params params;
+
+    if (whisper_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    params.keep_ms   = std::max(0, std::min(params.keep_ms, params.step_ms)); // never negative - used as a sample count below
+    params.length_ms = std::max(params.length_ms, params.step_ms);
+
+    const int n_samples_step = (1e-3*params.step_ms  )*WHISPER_SAMPLE_RATE;
+    const int n_samples_len  = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
+    const int n_samples_keep = (1e-3*params.keep_ms  )*WHISPER_SAMPLE_RATE;
+    const int n_samples_30s  = (1e-3*30000.0         )*WHISPER_SAMPLE_RATE;
+
+    const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
+
+    const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
+
+    params.no_timestamps  = !use_vad;
+    params.no_context    |= use_vad;
+    params.max_tokens     = 0; // NOTE: this discards any value given via --max-tokens
+
+    // init audio
+
+    audio_async audio(params.length_ms);
+    if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
+        fprintf(stderr, "%s: audio.init() failed!\n", __func__);
+        return 1;
+    }
+
+    audio.resume();
+
+    // whisper init
+
+    if (whisper_lang_id(params.language.c_str()) == -1) {
+        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+        whisper_print_usage(argc, argv, params);
+        return 1;
+    }
+
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    if (ctx == nullptr) {
+        fprintf(stderr, "error: failed to initialize whisper context from '%s'\n", params.model.c_str());
+        return 2;
+    }
+
+    std::vector<float> pcmf32    (n_samples_30s, 0.0f);
+    std::vector<float> pcmf32_old;
+    std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
+
+    std::vector<whisper_token> prompt_tokens;
+
+    // print some info about the processing
+    {
+        fprintf(stderr, "\n");
+        if (!whisper_is_multilingual(ctx)) {
+            if (params.language != "en" || params.translate) {
+                params.language = "en";
+                params.translate = false;
+                fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+            }
+        }
+        fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                __func__,
+                n_samples_step,
+                float(n_samples_step)/WHISPER_SAMPLE_RATE,
+                float(n_samples_len )/WHISPER_SAMPLE_RATE,
+                float(n_samples_keep)/WHISPER_SAMPLE_RATE,
+                params.n_threads,
+                params.language.c_str(),
+                params.translate ? "translate" : "transcribe",
+                params.no_timestamps ? 0 : 1);
+
+        if (!use_vad) {
+            fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
+        } else {
+            fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
+        }
+
+        fprintf(stderr, "\n");
+    }
+
+    int n_iter = 0;
+
+    bool is_running = true;
+
+    printf("[Start speaking]");
+    fflush(stdout);
+
+    auto t_last = std::chrono::high_resolution_clock::now();
+    const auto t_start = t_last;
+
+    // main audio loop
+    while (is_running) {
+        // handle Ctrl + C
+        {
+            SDL_Event event;
+            while (SDL_PollEvent(&event)) {
+                switch (event.type) {
+                    case SDL_QUIT:
+                        {
+                            is_running = false;
+                        } break;
+                    default:
+                        break;
+                }
+            }
+
+            if (!is_running) {
+                break;
+            }
+        }
+
+        // process new audio
+
+        if (!use_vad) {
+            while (true) {
+                audio.get(params.step_ms, pcmf32_new);
+
+                if ((int) pcmf32_new.size() > 2*n_samples_step) {
+                    fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
+                    audio.clear();
+                    continue;
+                }
+
+                if ((int) pcmf32_new.size() >= n_samples_step) {
+                    audio.clear();
+                    break;
+                }
+
+                SDL_Delay(1);
+            }
+
+            const int n_samples_new = pcmf32_new.size();
+
+            // take up to params.length_ms audio from previous iteration
+            const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
+
+            //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
+
+            pcmf32.resize(n_samples_new + n_samples_take);
+
+            for (int i = 0; i < n_samples_take; i++) {
+                pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
+            }
+
+            memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
+
+            pcmf32_old = pcmf32;
+        } else {
+            const auto t_now  = std::chrono::high_resolution_clock::now();
+            const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
+
+            if (t_diff < 2000) {
+                std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+                continue;
+            }
+
+            audio.get(2000, pcmf32_new);
+
+            if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
+                audio.get(params.length_ms, pcmf32);
+            } else {
+                std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+                continue;
+            }
+
+            t_last = t_now;
+        }
+
+        // run the inference
+        {
+            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+            wparams.print_progress   = false;
+            wparams.print_special    = params.print_special;
+            wparams.print_realtime   = false;
+            wparams.print_timestamps = !params.no_timestamps;
+            wparams.translate        = params.translate;
+            wparams.no_context       = true;
+            wparams.single_segment   = !use_vad;
+            wparams.max_tokens       = params.max_tokens;
+            wparams.language         = params.language.c_str();
+            wparams.n_threads        = params.n_threads;
+
+            wparams.audio_ctx        = params.audio_ctx;
+
+            // disable temperature fallback
+            wparams.temperature_inc  = -1.0f;
+
+            wparams.prompt_tokens    = params.no_context ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens  = params.no_context ? 0       : prompt_tokens.size();
+
+            if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                whisper_free(ctx);
+                return 6;
+            }
+
+            // print result;
+            {
+                if (!use_vad) {
+                    printf("\33[2K\r");
+
+                    // print long empty line to clear the previous line
+                    printf("%s", std::string(100, ' ').c_str());
+
+                    printf("\33[2K\r");
+                } else {
+                    // elapsed time in milliseconds, independent of the tick period of the clock
+                    const int64_t t1 = std::chrono::duration_cast<std::chrono::milliseconds>(t_last - t_start).count();
+                    const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
+
+                    printf("\n");
+                    printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
+                    printf("\n");
+                }
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+
+                    if (params.no_timestamps) {
+                        printf("%s", text);
+                        fflush(stdout);
+                    } else {
+                        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                        printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                    }
+                }
+
+                if (use_vad){
+                    printf("\n");
+                    printf("### Transcription %d END\n", n_iter);
+                }
+            }
+
+            ++n_iter;
+
+            if (!use_vad && (n_iter % n_new_line) == 0) {
+                printf("\n");
+
+                // keep part of the audio for next iteration to try to mitigate word boundary issues
+                pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
+
+                // Add tokens of the last full length segment as the prompt
+                if (!params.no_context) {
+                    prompt_tokens.clear();
+
+                    const int n_segments = whisper_full_n_segments(ctx);
+                    for (int i = 0; i < n_segments; ++i) {
+                        const int token_count = whisper_full_n_tokens(ctx, i);
+                        for (int j = 0; j < token_count; ++j) {
+                            prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    audio.pause();
+
+    whisper_print_timings(ctx);
+    whisper_free(ctx);
+
+    return 0;
+}