From dd1f4dfbab5d405d27f0f7c95935fdc44c80dda5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Oct 2022 19:31:17 +0300 Subject: [PATCH] whisper : various fixes --- examples/whisper/main.cpp | 76 +++++++++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 14 deletions(-) diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 79935c9..b39f360 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -1859,7 +1859,7 @@ whisper_vocab::id whisper_sample_best( if (need_timestamp) { // at the end of the 30-second audio segment, we start giving preference to time tokens for (int i = 0; i < top_k; i++) { - if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) { + if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) { return probs_id[i].second; } } @@ -1909,8 +1909,31 @@ whisper_vocab::id whisper_sample_timestamp( return probs_id[0].second; } +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +void dft(const std::vector & in, std::vector & out) { + int N = in.size(); + + out.resize(N*2); + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + float angle = 2*M_PI*k*n/N; + re += in[n]*cos(angle); + im -= in[n]*sin(angle); + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + // Cooley-Tukey FFT -// poor man's implmentation - use something better +// poor man's implementation - use something better // input is real-valued // output is complex-valued void fft(const std::vector & in, std::vector & out) { @@ -1924,6 +1947,11 @@ void fft(const std::vector & in, std::vector & out) { return; } + if (N%2 == 1) { + dft(in, out); + return; + } + std::vector even; std::vector odd; @@ -2014,9 +2042,20 @@ bool log_mel_spectrogram( // FFT -> mag^2 fft(fft_in, fft_out); - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < fft_size; j++) { fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); } + for (int j = 1; j < fft_size/2; j++) { + //if (i == 0) { + // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); + //} + fft_out[j] += fft_out[fft_size - j]; + } + if (i == 0) { + //for (int j = 0; j < fft_size; j++) { + // printf("%d: %e\n", j, fft_out[j]); + //} + } // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { @@ -2048,6 +2087,7 @@ bool log_mel_spectrogram( mmax = mel.data[i]; } } + //printf("%s: max = %f\n", __func__, mmax); mmax -= 8.0; @@ -2125,8 +2165,8 @@ int main(int argc, char ** argv) { return 2; } - if (wav.channels != 1) { - fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str()); + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str()); return 3; } @@ -2140,15 +2180,23 @@ int main(int argc, char ** argv) { return 5; } + int n = wav.totalPCMFrameCount; + std::vector pcm16; - pcm16.resize(wav.totalPCMFrameCount); - drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data()); + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); drwav_uninit(&wav); - // convert to float - pcmf32.resize(pcm16.size()); - for (size_t i = 0; i < pcm16.size(); i++) { - pcmf32[i] = float(pcm16[i])/32768.0f; + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (size_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (size_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } } } @@ -2195,7 +2243,7 @@ int main(int argc, char ** argv) { } // the generated text including timestamps - std::vector result_all; + //std::vector result_all; // main loop int seek = 0; @@ -2252,7 +2300,7 @@ int main(int argc, char ** argv) { int result_len = 0; std::vector result_cur; - for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) { + for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) { // decode if (prompt.size() > 0) { const int64_t t_start_us = ggml_time_us(); @@ -2317,7 +2365,7 @@ int main(int argc, char ** argv) { } result_cur.resize(result_len); - result_all.insert(result_all.end(), result_cur.begin(), result_cur.end()); + //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end()); for (const auto & r : result_cur) { prompt_past.push_back(r.id);