From 3f15bb8a08efc91b4fad6ecf9520771db5177692 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 7 Oct 2022 22:07:24 +0300 Subject: [PATCH] ref #10 : add "step" argument for "stream" example Controls how often we run the inference. By default, we run it every 3 seconds. --- stream.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/stream.cpp b/stream.cpp index 1f84d66..18a3f6f 100644 --- a/stream.cpp +++ b/stream.cpp @@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t) { struct whisper_params { int32_t seed = -1; // RNG seed, not used currently int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t step_ms = 3000; bool verbose = false; bool translate = false; @@ -57,6 +58,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.seed = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); + } else if (arg == "--step") { + params.step_ms = std::stoi(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; } else if (arg == "--translate") { @@ -97,6 +100,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); @@ -197,6 +201,7 @@ int main(int argc, char ** argv) { struct whisper_context * ctx = whisper_init(params.model.c_str()); + const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE; const int n_samples_30s = 30*WHISPER_SAMPLE_RATE; std::vector pcmf32(n_samples_30s, 0.0f); std::vector pcmf32_old; @@ -212,7 +217,7 @@ int main(int argc, char ** argv) { } } printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", - __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, + __func__, n_samples, float(n_samples)/WHISPER_SAMPLE_RATE, params.n_threads, params.language.c_str(), params.translate ? "translate" : "transcribe", params.no_timestamps ? 0 : 1); @@ -238,7 +243,7 @@ int main(int argc, char ** argv) { } // process 3 seconds of new audio - while (SDL_GetQueuedAudioSize(g_dev_id_in) < 3*WHISPER_SAMPLE_RATE*sizeof(float)) { + while (SDL_GetQueuedAudioSize(g_dev_id_in) < n_samples*sizeof(float)) { SDL_Delay(1); } const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);