From 12a0558317683e9e05f02dbf34b725f45e84bd58 Mon Sep 17 00:00:00 2001 From: Eren Akbiyik Date: Mon, 21 Nov 2022 11:27:50 +0000 Subject: [PATCH] feat: prompt previous tokens for streaming prompt_tokens is passed as a pointer to a vector rather than as a vector by value, because storing the vector itself in whisper_full_params produced errors. --- examples/stream/stream.cpp | 13 +++++++++++++ whisper.cpp | 10 ++++++++++ whisper.h | 2 ++ 3 files changed, 25 insertions(+) diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 6f3634b..2256e69 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -234,6 +234,7 @@ int main(int argc, char ** argv) { std::vector pcmf32(n_samples_30s, 0.0f); std::vector pcmf32_old; + std::vector prompt_tokens; const int n_new_line = params.length_ms / params.step_ms - 1; // print some info about the processing @@ -344,6 +345,8 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + wparams.prompt_tokens = &prompt_tokens; + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; @@ -393,6 +396,16 @@ int main(int argc, char ** argv) { // keep part of the audio for next iteration to try to mitigate word boundary issues pcmf32_old = std::vector(pcmf32.end() - n_samples_keep, pcmf32.end()); + + // Add tokens of the last full length segment as the prompt + prompt_tokens.clear(); + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const int token_count = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < token_count; ++j) { + prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); + } + } } } } diff --git a/whisper.cpp b/whisper.cpp index d35b90f..48bdb31 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2412,6 +2412,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.speed_up =*/ false, /*.audio_ctx =*/ 0, + /*.prompt_tokens =*/ nullptr, 
/*.language =*/ "en", /*.greedy =*/ { @@ -2455,6 +2456,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.speed_up =*/ false, /*.audio_ctx =*/ 0, + /*.prompt_tokens =*/ nullptr, /*.language =*/ "en", /*.greedy =*/ { @@ -2584,6 +2586,14 @@ int whisper_full( prompt_past.clear(); } + // Prepend the prompt tokens to the prompt_past + if (params.prompt_tokens) { + for (int i = 0; i < (int) params.prompt_tokens->size(); i++) { + prompt_past.push_back((*params.prompt_tokens)[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_tokens->size(), prompt_past.end()); + } + // overwrite audio_ctx ctx->exp_n_audio_ctx = params.audio_ctx; diff --git a/whisper.h b/whisper.h index 88cc711..917b459 100644 --- a/whisper.h +++ b/whisper.h @@ -3,6 +3,7 @@ #include #include +#include #ifdef WHISPER_SHARED # ifdef _WIN32 @@ -208,6 +209,7 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) + const std::vector * prompt_tokens; const char * language; struct {