feat: prompt previous tokens for streaming

I used a vector pointer instead of vector itself because it gave weird errors, and why not
pull/163/head
Eren Akbiyik 3 years ago
parent 2065572a11
commit 12a0558317

@ -234,6 +234,7 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32(n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<whisper_token> prompt_tokens;
const int n_new_line = params.length_ms / params.step_ms - 1;
// print some info about the processing
@ -344,6 +345,8 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = &prompt_tokens;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6;
@ -393,6 +396,16 @@ int main(int argc, char ** argv) {
// keep part of the audio for next iteration to try to mitigate word boundary issues
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
// Add tokens of the last full length segment as the prompt
prompt_tokens.clear();
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
}
}
}
}
}

@ -2412,6 +2412,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2455,6 +2456,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2584,6 +2586,14 @@ int whisper_full(
prompt_past.clear();
}
// Prepend the prompt tokens to the prompt_past
if (params.prompt_tokens) {
for (int i = 0; i < (int) params.prompt_tokens->size(); i++) {
prompt_past.push_back((*params.prompt_tokens)[i]);
}
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_tokens->size(), prompt_past.end());
}
// overwrite audio_ctx
ctx->exp_n_audio_ctx = params.audio_ctx;

@ -3,6 +3,7 @@
#include <stdint.h>
#include <stdbool.h>
#include <vector>
#ifdef WHISPER_SHARED
# ifdef _WIN32
@ -208,6 +209,7 @@ extern "C" {
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
const std::vector<whisper_token> * prompt_tokens;
const char * language;
struct {

Loading…
Cancel
Save