Fixed whisper_full_parallel for the single-processor case (n_processors == 1) and optimized the multi-processor path.

pull/494/head
Sandro Hanea 2 years ago
parent 8841840226
commit 0cd3bdb1ff

@ -4258,7 +4258,7 @@ int whisper_full_parallel(
int n_samples,
int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
return whisper_full_with_state(ctx, result_state, params, samples, n_samples);
}
int ret = 0;
@ -4269,16 +4269,14 @@ int whisper_full_parallel(
std::vector<whisper_state*> states{};
for (int i = 0; i < n_processors; i++)
{
states.push_back(whisper_init_state(ctx));
}
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std::vector<std::thread> workers(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
// create a new state for each thread
states.push_back(whisper_init_state(ctx));
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
@ -4291,7 +4289,7 @@ int whisper_full_parallel(
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_with_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur);
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
@ -4300,15 +4298,15 @@ int whisper_full_parallel(
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
params_cur.print_realtime = false;
// Run the first transformation using the first state and for the first chunk.
ret = whisper_full_with_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor);
// Run the first transformation using given state but only for the first chunk.
ret = whisper_full_with_state(ctx, result_state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {
workers[i].join();
}
// combine results into ctx->result_all
// combine results into result_state->result_all from all other states
for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = states[i]->result_all;

Loading…
Cancel
Save