diff --git a/whisper.cpp b/whisper.cpp index 83b1ce3..9c4d3e8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -4258,7 +4258,7 @@ int whisper_full_parallel( int n_samples, int n_processors) { if (n_processors == 1) { - return whisper_full(ctx, params, samples, n_samples); + return whisper_full_with_state(ctx, result_state, params, samples, n_samples); } int ret = 0; @@ -4269,16 +4269,14 @@ int whisper_full_parallel( std::vector states{}; - for (int i = 0; i < n_processors; i++) - { - states.push_back(whisper_init_state(ctx)); - } - // the calling thread will process the first chunk // while the other threads will process the remaining chunks std::vector workers(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; @@ -4291,7 +4289,7 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; - workers[i] = std::thread(whisper_full_with_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur); + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); } { @@ -4300,15 +4298,15 @@ int whisper_full_parallel( // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. params_cur.print_realtime = false; - // Run the first transformation using the first state and for the first chunk. - ret = whisper_full_with_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor); + // Run the first transformation using given state but only for the first chunk. + ret = whisper_full_with_state(ctx, result_state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); } for (int i = 0; i < n_processors - 1; ++i) { workers[i].join(); } - // combine results into ctx->result_all + // combine results into result_state->result_all from all other states for (int i = 0; i < n_processors - 1; ++i) { auto & results_i = states[i]->result_all;