diff --git a/stream.cpp b/stream.cpp index 18a3f6f..f927819 100644 --- a/stream.cpp +++ b/stream.cpp @@ -40,6 +40,7 @@ struct whisper_params { bool verbose = false; bool translate = false; + bool no_context = true; bool print_special_tokens = false; bool no_timestamps = true; @@ -64,6 +65,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.verbose = true; } else if (arg == "--translate") { params.translate = true; + } else if (arg == "-kc" || arg == "--keep-context") { + params.no_context = false; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; if (whisper_lang_id(params.language.c_str()) == -1) { @@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); + fprintf(stderr, " -kc, --keep-context keep context from earlier audio (default: false)\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); @@ -273,6 +277,7 @@ int main(int argc, char ** argv) { wparams.print_realtime = false; wparams.print_timestamps = !params.no_timestamps; wparams.translate = params.translate; + wparams.no_context = params.no_context; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; diff --git a/whisper.cpp b/whisper.cpp index ca0c6a4..9913ab6 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -405,6 +405,8 @@ struct whisper_context { std::vector result_cur; std::vector result_all; + + std::vector prompt_past; }; // load the model from a ggml file @@ -1020,8 +1022,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // - 
model: the model // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) -// - mel_inp: input mel spectrogram -// - features: output encoded features // bool whisper_encode( whisper_context & wctx, @@ -1405,10 +1405,9 @@ bool whisper_encode( // // - model: the model // - n_threads: number of threads to use -// - n_past: prompt length -// - prompt: text prompt -// - logits_out: output logits -// - probs_out: output probabilities +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with // bool whisper_decode( whisper_context & wctx, @@ -2259,6 +2258,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat .offset_ms = 0, .translate = false, + .no_context = false, .print_special_tokens = false, .print_progress = true, .print_realtime = false, @@ -2279,6 +2279,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat .offset_ms = 0, .translate = false, + .no_context = false, .print_special_tokens = false, .print_progress = true, .print_realtime = false, @@ -2297,6 +2298,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat return result; } + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -2309,7 +2311,10 @@ int whisper_full( } // the accumulated text context so far - std::vector prompt_past = { }; + auto & prompt_past = ctx->prompt_past; + if (params.no_context) { + prompt_past.clear(); + } // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; diff --git a/whisper.h b/whisper.h index 78e08b7..79df0e0 100644 --- a/whisper.h +++ b/whisper.h @@ -105,6 +105,7 @@ extern "C" { int offset_ms; bool translate; + bool no_context; bool print_special_tokens; bool print_progress; bool print_realtime;