whisper : add option to speed up the audio tempo by x2

Using a Phase Vocoder for speeding up the audio tempo by scaling down
the frequencies in the frequency domain.

This reduces the computation in the Encoder by a factor of 2.
The transcription accuracy is degraded, but for slow to normal speech -
it seems to be still very good.

I think this can find application for real-time transcription - i.e. the
"stream" example.
pull/147/head
Georgi Gerganov 2 years ago
parent 41b48ab7f1
commit 83c742f1a7

@ -59,6 +59,7 @@ struct whisper_params {
float word_thold = 0.01f;
bool speed_up = false;
bool verbose = false;
bool translate = false;
bool output_txt = false;
@ -104,6 +105,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-su" || arg == "--speed-up") {
params.speed_up = true;
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
@ -161,6 +164,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -su, --speed-up speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
@ -454,7 +458,7 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32;
{
drwav wav;
if (fname_inp == "-") {
std::vector<uint8_t> wav_data;
{
@ -563,6 +567,8 @@ int main(int argc, char ** argv) {
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;

@ -41,6 +41,7 @@ struct whisper_params {
int32_t length_ms = 10000;
int32_t capture_id = -1;
bool speed_up = false;
bool verbose = false;
bool translate = false;
bool no_context = true;
@ -68,6 +69,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.length_ms = std::stoi(argv[++i]);
} else if (arg == "-c" || arg == "--capture") {
params.capture_id = std::stoi(argv[++i]);
} else if (arg == "-su" || arg == "--speed-up") {
params.speed_up = true;
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
@ -113,6 +116,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
fprintf(stderr, " --length N audio length in milliseconds (default: %d)\n", params.length_ms);
fprintf(stderr, " -c ID, --capture ID capture device ID (default: -1)\n");
fprintf(stderr, " -su, --speed-up speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -kc, --keep-context keep text context from earlier audio (default: false)\n");
@ -326,6 +330,8 @@ int main(int argc, char ** argv) {
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6;

@ -2031,6 +2031,7 @@ static bool log_mel_spectrogram(
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) {
// Hanning window
@ -2044,7 +2045,7 @@ static bool log_mel_spectrogram(
mel.n_len = (n_samples)/fft_step;
mel.data.resize(mel.n_mel*mel.n_len);
const int n_fft = 1 + fft_size/2;
const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
@ -2091,6 +2092,13 @@ static bool log_mel_spectrogram(
//}
}
if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
}
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
@ -2171,7 +2179,21 @@ void whisper_free(struct whisper_context * ctx) {
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
@ -2353,6 +2375,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.speed_up =*/ false,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2391,6 +2415,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.speed_up =*/ false,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2485,9 +2511,16 @@ int whisper_full(
result_all.clear();
// compute log mel spectrogram
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
}
if (params.token_timestamps) {
@ -2673,16 +2706,19 @@ int whisper_full(
if (tokens_cur[i].id > whisper_token_beg(ctx)) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -2714,16 +2750,19 @@ int whisper_full(
if (!text.empty()) {
const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}

@ -202,6 +202,9 @@ extern "C" {
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
// [EXPERIMENTAL] speed-up techniques
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
const char * language;
struct {

Loading…
Cancel
Save