Fixed bench and main examples to use state.

pull/494/head
Sandro Hanea 2 years ago
parent 704566bb9f
commit b6aa7becb6

@ -55,6 +55,8 @@ int whisper_bench_encoder(const whisper_params & params) {
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
@ -65,17 +67,17 @@ int whisper_bench_encoder(const whisper_params & params) {
return 2;
}
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
if (int ret = whisper_set_mel(state, nullptr, 0, WHISPER_N_MEL)) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return 3;
}
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
if (int ret = whisper_encode(ctx, state, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
whisper_print_timings(ctx);
whisper_print_timings(ctx, state);
whisper_free(ctx);
fprintf(stderr, "\n");

@ -191,11 +191,11 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s;
};
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
std::string speaker = "";
@ -211,8 +211,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
t0 = whisper_full_get_segment_t0(state, i);
t1 = whisper_full_get_segment_t1(state, i);
}
if (!params.no_timestamps) {
@ -245,23 +245,23 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
for (int j = 0; j < whisper_full_n_tokens(state, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
const whisper_token id = whisper_full_get_token_id(state, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const char * text = whisper_full_get_token_text(ctx, state, i, j);
const float p = whisper_full_get_token_p (state, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
printf("%s%s", speaker.c_str(), text);
}
@ -275,7 +275,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) {
bool output_txt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -284,16 +284,16 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
fout << text << "\n";
}
return true;
}
bool output_vtt(struct whisper_context * ctx, const char * fname) {
bool output_vtt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -304,11 +304,11 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
fout << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(state, i);
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << text << "\n\n";
@ -317,7 +317,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
return true;
}
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
bool output_srt(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const whisper_params & params) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -326,11 +326,11 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(state, i);
fout << i + 1 + params.offset_n << "\n";
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
@ -340,7 +340,7 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
bool output_csv(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -349,11 +349,11 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(state, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
@ -365,7 +365,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
bool output_wts(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -378,15 +378,15 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
for (int i = 0; i < whisper_full_n_segments(state); i++) {
const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(state, i);
const int n = whisper_full_n_tokens(ctx, i);
const int n = whisper_full_n_tokens(state, i);
std::vector<whisper_token_data> tokens(n);
for (int j = 0; j < n; ++j) {
tokens[j] = whisper_full_get_token_data(ctx, i, j);
tokens[j] = whisper_full_get_token_data(state, i, j);
}
if (i > 0) {
@ -502,6 +502,9 @@ int main(int argc, char ** argv) {
return 3;
}
// state init for the current transformation
struct whisper_state* state = whisper_init_state(ctx);
// initial prompt
std::vector<whisper_token> prompt_tokens;
@ -682,14 +685,14 @@ int main(int argc, char ** argv) {
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void* user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
if (whisper_full_parallel(ctx, state, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
}
@ -702,37 +705,38 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_outp + ".txt";
output_txt(ctx, fname_txt.c_str());
output_txt(ctx, state, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_outp + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
output_vtt(ctx, state, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_outp + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
output_srt(ctx, state, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_outp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
output_wts(ctx, state, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_outp + ".csv";
output_csv(ctx, fname_csv.c_str());
output_csv(ctx, state, fname_csv.c_str());
}
}
}
whisper_print_timings(ctx);
whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx);
return 0;

@ -567,7 +567,7 @@ struct whisper_state {
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute{};
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
int buf_last = 0;
@ -576,11 +576,11 @@ struct whisper_state {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
std::vector<whisper_segment> result_all{};
std::vector<whisper_token> prompt_past{};
// work container used to avoid memory allocations
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
std::vector<std::pair<double, whisper_vocab::id>> logits_id{};
mutable std::mt19937 rng; // used for sampling at t > 0.0
@ -590,7 +590,7 @@ struct whisper_state {
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
std::vector<float> energy{}; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
@ -4248,6 +4248,7 @@ int whisper_full(
int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_state * result_state,
struct whisper_full_params params,
const float * samples,
int n_samples,
@ -4262,8 +4263,6 @@ int whisper_full_parallel(
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
const int64_t offset_t = (int64_t)params.offset_ms / 10.0;
auto aggregated_whisper_state = whisper_init_state(ctx);
std::vector<whisper_state*> states{};
for (int i = 0; i < n_processors; i++)
@ -4316,32 +4315,32 @@ int whisper_full_parallel(
// make sure that segments are not overlapping
if (!aggregated_whisper_state->result_all.empty()) {
result.t0 = std::max(result.t0, aggregated_whisper_state->result_all.back().t1);
if (!result_state->result_all.empty()) {
result.t0 = std::max(result.t0, result_state->result_all.back().t1);
}
aggregated_whisper_state->result_all.push_back(std::move(result));
result_state->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, aggregated_whisper_state, 1, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, result_state, 1, params.new_segment_callback_user_data);
}
}
aggregated_whisper_state->t_mel_us += states[i]->t_mel_us;
result_state->t_mel_us += states[i]->t_mel_us;
aggregated_whisper_state->t_sample_us += states[i]->t_sample_us;
aggregated_whisper_state->t_encode_us += states[i]->t_encode_us;
aggregated_whisper_state->t_decode_us += states[i]->t_decode_us;
result_state->t_sample_us += states[i]->t_sample_us;
result_state->t_encode_us += states[i]->t_encode_us;
result_state->t_decode_us += states[i]->t_decode_us;
whisper_free_state(states[i]);
}
// average the timings
aggregated_whisper_state->t_mel_us /= n_processors;
aggregated_whisper_state->t_sample_us /= n_processors;
aggregated_whisper_state->t_encode_us /= n_processors;
aggregated_whisper_state->t_decode_us /= n_processors;
result_state->t_mel_us /= n_processors;
result_state->t_sample_us /= n_processors;
result_state->t_encode_us /= n_processors;
result_state->t_decode_us /= n_processors;
// print information about the audio boundaries
fprintf(stderr, "\n");
@ -4351,7 +4350,7 @@ int whisper_full_parallel(
}
fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
whisper_free_state(aggregated_whisper_state);
whisper_free_state(result_state);
return ret;
}

@ -344,6 +344,7 @@ extern "C" {
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples,

Loading…
Cancel
Save