diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 49daaa0..fbfe4d1 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -55,6 +55,8 @@ int whisper_bench_encoder(const whisper_params & params) { struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_state * state = whisper_init_state(ctx); + { fprintf(stderr, "\n"); fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info()); @@ -65,17 +67,17 @@ int whisper_bench_encoder(const whisper_params & params) { return 2; } - if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) { + if (int ret = whisper_set_mel(state, nullptr, 0, WHISPER_N_MEL)) { fprintf(stderr, "error: failed to set mel: %d\n", ret); return 3; } - if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { + if (int ret = whisper_encode(ctx, state, 0, params.n_threads) != 0) { fprintf(stderr, "error: failed to encode model: %d\n", ret); return 4; } - whisper_print_timings(ctx); + whisper_print_timings(ctx, state); whisper_free(ctx); fprintf(stderr, "\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fbc9faf..0dba72f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -191,11 +191,11 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); std::string speaker = ""; @@ -211,8 +211,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi for (int i = s0; i < n_segments; i++) { if (!params.no_timestamps || params.diarize) { - t0 = whisper_full_get_segment_t0(ctx, i); - t1 = whisper_full_get_segment_t1(ctx, i); + t0 = whisper_full_get_segment_t0(state, i); + t1 = whisper_full_get_segment_t1(state, i); } if (!params.no_timestamps) { @@ -245,23 +245,23 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi } if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + for (int j = 0; j < whisper_full_n_tokens(state, i); ++j) { if (params.print_special == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); + const whisper_token id = whisper_full_get_token_id(state, i, j); if (id >= whisper_token_eot(ctx)) { continue; } } - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); + const char * text = whisper_full_get_token_text(ctx, state, i, j); + const float p = whisper_full_get_token_p (state, i, j); const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } } else { - const char * text = whisper_full_get_segment_text(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); printf("%s%s", speaker.c_str(), text); } @@ -275,7 +275,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi } } -bool output_txt(struct whisper_context * ctx, const char * fname) { +bool output_txt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -284,16 +284,16 @@ bool output_txt(struct whisper_context * ctx, const char * fname) { fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); fout << text << "\n"; } return true; } -bool output_vtt(struct whisper_context * ctx, const char * fname) { +bool output_vtt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -304,11 +304,11 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) { fout << "WEBVTT\n\n"; - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); + const int64_t t0 = whisper_full_get_segment_t0(state, i); + const int64_t t1 = whisper_full_get_segment_t1(state, i); fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; fout << text << "\n\n"; @@ -317,7 +317,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) { return true; } -bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) { +bool output_srt(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const whisper_params & params) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -326,11 +326,11 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); + const int64_t t0 = whisper_full_get_segment_t0(state, i); + const int64_t t1 = whisper_full_get_segment_t1(state, i); fout << i + 1 + params.offset_n << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; @@ -340,7 +340,7 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ return true; } -bool output_csv(struct whisper_context * ctx, const char * fname) { +bool output_csv(struct whisper_context * ctx, struct whisper_state * state, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -349,11 +349,11 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(state); for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const char * text = whisper_full_get_segment_text(state, i); + const int64_t t0 = whisper_full_get_segment_t0(state, i); + const int64_t t1 = whisper_full_get_segment_t1(state, i); //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; @@ -365,7 +365,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) { +bool output_wts(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) { std::ofstream fout(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); @@ -378,15 +378,15 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; - for (int i = 0; i < whisper_full_n_segments(ctx); i++) { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + for (int i = 0; i < whisper_full_n_segments(state); i++) { + const int64_t t0 = whisper_full_get_segment_t0(state, i); + const int64_t t1 = whisper_full_get_segment_t1(state, i); - const int n = whisper_full_n_tokens(ctx, i); + const int n = whisper_full_n_tokens(state, i); std::vector tokens(n); for (int j = 0; j < n; ++j) { - tokens[j] = whisper_full_get_token_data(ctx, i, j); + tokens[j] = whisper_full_get_token_data(state, i, j); } if (i > 0) { @@ -502,6 +502,9 @@ int main(int argc, char ** argv) { return 3; } + // state init for the current transformation + struct whisper_state* state = whisper_init_state(ctx); + // initial prompt std::vector prompt_tokens; @@ -682,14 +685,14 @@ int main(int argc, char ** argv) { { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void* user_data) { bool is_aborted = *(bool*)user_data; return !is_aborted; }; wparams.encoder_begin_callback_user_data = &is_aborted; } - if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { + if (whisper_full_parallel(ctx, state, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 10; } @@ -702,37 +705,38 @@ int main(int argc, char ** argv) { // output to text file if (params.output_txt) { const auto fname_txt = fname_outp + ".txt"; - output_txt(ctx, fname_txt.c_str()); + output_txt(ctx, state, fname_txt.c_str()); } // output to VTT file if (params.output_vtt) { const auto fname_vtt = fname_outp + ".vtt"; - output_vtt(ctx, fname_vtt.c_str()); + output_vtt(ctx, state, fname_vtt.c_str()); } // output to SRT file if (params.output_srt) { const auto fname_srt = fname_outp + ".srt"; - output_srt(ctx, fname_srt.c_str(), params); + output_srt(ctx, state, fname_srt.c_str(), params); } // output to WTS file if (params.output_wts) { const auto fname_wts = fname_outp + ".wts"; - output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); + output_wts(ctx, state, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); } // output to CSV file if (params.output_csv) { const auto fname_csv = fname_outp + ".csv"; - output_csv(ctx, fname_csv.c_str()); + output_csv(ctx, state, fname_csv.c_str()); } } } - whisper_print_timings(ctx); + whisper_print_timings(ctx, state); + whisper_free_state(state); whisper_free(ctx); return 0; diff --git a/whisper.cpp b/whisper.cpp index 0a54654..acb31da 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -567,7 +567,7 @@ struct whisper_state { whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; // memory buffers used by encode / decode contexts - std::vector buf_compute; + std::vector buf_compute{}; std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; int buf_last = 0; @@ -576,11 +576,11 @@ struct whisper_state { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; - std::vector result_all; - std::vector prompt_past; + std::vector result_all{}; + std::vector prompt_past{}; // work container used to avoid memory allocations - std::vector> logits_id; + std::vector> logits_id{}; mutable std::mt19937 rng; // used for sampling at t > 0.0 @@ -590,7 +590,7 @@ struct whisper_state { int64_t t_beg; int64_t t_last; whisper_token tid_last; - std::vector energy; // PCM signal energy + std::vector energy{}; // PCM signal energy // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx; // 0 - use default @@ -4248,6 +4248,7 @@ int whisper_full( int whisper_full_parallel( struct whisper_context * ctx, + struct whisper_state * result_state, struct whisper_full_params params, const float * samples, int n_samples, @@ -4262,8 +4263,6 @@ int whisper_full_parallel( const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; const int64_t offset_t = (int64_t)params.offset_ms / 10.0; - auto aggregated_whisper_state = whisper_init_state(ctx); - std::vector states{}; for (int i = 0; i < n_processors; i++) @@ -4316,32 +4315,32 @@ int whisper_full_parallel( // make sure that segments are not overlapping - if (!aggregated_whisper_state->result_all.empty()) { - result.t0 = std::max(result.t0, aggregated_whisper_state->result_all.back().t1); + if (!result_state->result_all.empty()) { + result.t0 = std::max(result.t0, result_state->result_all.back().t1); } - aggregated_whisper_state->result_all.push_back(std::move(result)); + result_state->result_all.push_back(std::move(result)); // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, aggregated_whisper_state, 1, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, result_state, 1, params.new_segment_callback_user_data); } } - aggregated_whisper_state->t_mel_us += states[i]->t_mel_us; + result_state->t_mel_us += states[i]->t_mel_us; - aggregated_whisper_state->t_sample_us += states[i]->t_sample_us; - aggregated_whisper_state->t_encode_us += states[i]->t_encode_us; - aggregated_whisper_state->t_decode_us += states[i]->t_decode_us; + result_state->t_sample_us += states[i]->t_sample_us; + result_state->t_encode_us += states[i]->t_encode_us; + result_state->t_decode_us += states[i]->t_decode_us; whisper_free_state(states[i]); } // average the timings - aggregated_whisper_state->t_mel_us /= n_processors; - aggregated_whisper_state->t_sample_us /= n_processors; - aggregated_whisper_state->t_encode_us /= n_processors; - aggregated_whisper_state->t_decode_us /= n_processors; + result_state->t_mel_us /= n_processors; + result_state->t_sample_us /= n_processors; + result_state->t_encode_us /= n_processors; + result_state->t_decode_us /= n_processors; // print information about the audio boundaries fprintf(stderr, "\n"); @@ -4351,7 +4350,7 @@ int whisper_full_parallel( } fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); - whisper_free_state(aggregated_whisper_state); + whisper_free_state(result_state); return ret; } diff --git a/whisper.h b/whisper.h index 0f79602..a1e2ff2 100644 --- a/whisper.h +++ b/whisper.h @@ -344,6 +344,7 @@ extern "C" { // However, the transcription accuracy can be worse at the beginning and end of each chunk. WHISPER_API int whisper_full_parallel( struct whisper_context * ctx, + struct whisper_state * state, struct whisper_full_params params, const float * samples, int n_samples,