From b211d2f9b1eb1ac05198b83d46913d8b19d11d13 Mon Sep 17 00:00:00 2001 From: Niranjan Yadla Date: Wed, 5 Oct 2022 15:01:23 -0700 Subject: [PATCH] ADD VAD to save inference cycles --- stream.cpp | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/stream.cpp b/stream.cpp index e9d0364..aacc9ec 100644 --- a/stream.cpp +++ b/stream.cpp @@ -17,6 +17,34 @@ #include #include #include +#include + +// vad settings +const double MIN_ENERGY = 0.9; +const double INCREASE_FACTOR = 0.0008; +const double FRAME_MARGIN = 5; + +// vad algorithm support variables +// TODO struct +size_t n_frame = 0; +double emin, emax; +double delta; +size_t margin_frame_counter; + +double vad(float *buffer, size_t packet_length, size_t n_frame, size_t *mf_counter, double *emin, double *emax, double *delta) +{ + + // calculate energy of current frame (RMSE) + double current_energy = 0; + for (int i = 0; i < packet_length; i++) + { + current_energy += pow((double)buffer[i], 2); + } + + current_energy = sqrt(current_energy / (double)packet_length); + printf("\ncurrent energy: %f\n", current_energy); + return current_energy; +} // 500 -> 00:05.000 // 6000 -> 01:00.000 @@ -241,8 +269,8 @@ int main(int argc, char ** argv) { while (SDL_GetQueuedAudioSize(g_dev_id_in) < 3*WHISPER_SAMPLE_RATE*sizeof(float)) { SDL_Delay(1); } - const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float); + const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float); // take one second from previous iteration // TODO: better strategy const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new)); @@ -259,8 +287,10 @@ int main(int argc, char ** argv) { pcmf32_old = pcmf32; - // run the inference - { + double vad_val; + vad_val = vad(pcmf32.data(), pcmf32.size(), n_frame++, &margin_frame_counter, &emin, &emax, &delta); + if (vad_val > 0.2f) // if vad says it's a voice packet... + { // run the inference whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); wparams.print_progress = false;