whisper : various fixes

experiments/blocking
Georgi Gerganov 2 years ago
parent 0116c03fb7
commit dd1f4dfbab
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1859,7 +1859,7 @@ whisper_vocab::id whisper_sample_best(
if (need_timestamp) { if (need_timestamp) {
// at the end of the 30-second audio segment, we start giving preference to time tokens // at the end of the 30-second audio segment, we start giving preference to time tokens
for (int i = 0; i < top_k; i++) { for (int i = 0; i < top_k; i++) {
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) { if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
return probs_id[i].second; return probs_id[i].second;
} }
} }
@ -1909,8 +1909,31 @@ whisper_vocab::id whisper_sample_timestamp(
return probs_id[0].second; return probs_id[0].second;
} }
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
float angle = 2*M_PI*k*n/N;
re += in[n]*cos(angle);
im -= in[n]*sin(angle);
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT // Cooley-Tukey FFT
// poor man's implmentation - use something better // poor man's implementation - use something better
// input is real-valued // input is real-valued
// output is complex-valued // output is complex-valued
void fft(const std::vector<float> & in, std::vector<float> & out) { void fft(const std::vector<float> & in, std::vector<float> & out) {
@ -1924,6 +1947,11 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
return; return;
} }
if (N%2 == 1) {
dft(in, out);
return;
}
std::vector<float> even; std::vector<float> even;
std::vector<float> odd; std::vector<float> odd;
@ -2014,9 +2042,20 @@ bool log_mel_spectrogram(
// FFT -> mag^2 // FFT -> mag^2
fft(fft_in, fft_out); fft(fft_in, fft_out);
for (int j = 0; j < n_fft; j++) { for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
} }
for (int j = 1; j < fft_size/2; j++) {
//if (i == 0) {
// printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
//}
fft_out[j] += fft_out[fft_size - j];
}
if (i == 0) {
//for (int j = 0; j < fft_size; j++) {
// printf("%d: %e\n", j, fft_out[j]);
//}
}
// mel spectrogram // mel spectrogram
for (int j = 0; j < mel.n_mel; j++) { for (int j = 0; j < mel.n_mel; j++) {
@ -2048,6 +2087,7 @@ bool log_mel_spectrogram(
mmax = mel.data[i]; mmax = mel.data[i];
} }
} }
//printf("%s: max = %f\n", __func__, mmax);
mmax -= 8.0; mmax -= 8.0;
@ -2125,8 +2165,8 @@ int main(int argc, char ** argv) {
return 2; return 2;
} }
if (wav.channels != 1) { if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
return 3; return 3;
} }
@ -2140,16 +2180,24 @@ int main(int argc, char ** argv) {
return 5; return 5;
} }
int n = wav.totalPCMFrameCount;
std::vector<int16_t> pcm16; std::vector<int16_t> pcm16;
pcm16.resize(wav.totalPCMFrameCount); pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data()); drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav); drwav_uninit(&wav);
// convert to float // convert to mono, float
pcmf32.resize(pcm16.size()); pcmf32.resize(n);
for (size_t i = 0; i < pcm16.size(); i++) { if (wav.channels == 1) {
for (size_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f; pcmf32[i] = float(pcm16[i])/32768.0f;
} }
} else {
for (size_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
} }
// compute log mel spectrogram // compute log mel spectrogram
@ -2195,7 +2243,7 @@ int main(int argc, char ** argv) {
} }
// the generated text including timestamps // the generated text including timestamps
std::vector<whisper_result> result_all; //std::vector<whisper_result> result_all;
// main loop // main loop
int seek = 0; int seek = 0;
@ -2252,7 +2300,7 @@ int main(int argc, char ** argv) {
int result_len = 0; int result_len = 0;
std::vector<whisper_result> result_cur; std::vector<whisper_result> result_cur;
for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) { for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) {
// decode // decode
if (prompt.size() > 0) { if (prompt.size() > 0) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
@ -2317,7 +2365,7 @@ int main(int argc, char ** argv) {
} }
result_cur.resize(result_len); result_cur.resize(result_len);
result_all.insert(result_all.end(), result_cur.begin(), result_cur.end()); //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
for (const auto & r : result_cur) { for (const auto & r : result_cur) {
prompt_past.push_back(r.id); prompt_past.push_back(r.id);

Loading…
Cancel
Save