Fixed bench and main examples to use state.

pull/494/head
Sandro Hanea 2 years ago
parent 704566bb9f
commit b6aa7becb6

@ -55,6 +55,8 @@ int whisper_bench_encoder(const whisper_params & params) {
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info()); fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
@ -65,17 +67,17 @@ int whisper_bench_encoder(const whisper_params & params) {
return 2; return 2;
} }
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) { if (int ret = whisper_set_mel(state, nullptr, 0, WHISPER_N_MEL)) {
fprintf(stderr, "error: failed to set mel: %d\n", ret); fprintf(stderr, "error: failed to set mel: %d\n", ret);
return 3; return 3;
} }
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { if (int ret = whisper_encode(ctx, state, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4; return 4;
} }
whisper_print_timings(ctx); whisper_print_timings(ctx, state);
whisper_free(ctx); whisper_free(ctx);
fprintf(stderr, "\n"); fprintf(stderr, "\n");

@ -191,11 +191,11 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s; const std::vector<std::vector<float>> * pcmf32s;
}; };
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
std::string speaker = ""; std::string speaker = "";
@ -211,8 +211,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
for (int i = s0; i < n_segments; i++) { for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) { if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i); t0 = whisper_full_get_segment_t0(state, i);
t1 = whisper_full_get_segment_t1(ctx, i); t1 = whisper_full_get_segment_t1(state, i);
} }
if (!params.no_timestamps) { if (!params.no_timestamps) {
@ -245,23 +245,23 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
} }
if (params.print_colors) { if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { for (int j = 0; j < whisper_full_n_tokens(state, i); ++j) {
if (params.print_special == false) { if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j); const whisper_token id = whisper_full_get_token_id(state, i, j);
if (id >= whisper_token_eot(ctx)) { if (id >= whisper_token_eot(ctx)) {
continue; continue;
} }
} }
const char * text = whisper_full_get_token_text(ctx, i, j); const char * text = whisper_full_get_token_text(ctx, state, i, j);
const float p = whisper_full_get_token_p (ctx, i, j); const float p = whisper_full_get_token_p (state, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
} else { } else {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
printf("%s%s", speaker.c_str(), text); printf("%s%s", speaker.c_str(), text);
} }
@ -275,7 +275,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
} }
} }
bool output_txt(struct whisper_context * ctx, const char * fname) { bool output_txt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -284,16 +284,16 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
fout << text << "\n"; fout << text << "\n";
} }
return true; return true;
} }
bool output_vtt(struct whisper_context * ctx, const char * fname) { bool output_vtt(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -304,11 +304,11 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
fout << "WEBVTT\n\n"; fout << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(state, i);
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << text << "\n\n"; fout << text << "\n\n";
@ -317,7 +317,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
return true; return true;
} }
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) { bool output_srt(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const whisper_params & params) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -326,11 +326,11 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(state, i);
fout << i + 1 + params.offset_n << "\n"; fout << i + 1 + params.offset_n << "\n";
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
@ -340,7 +340,7 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true; return true;
} }
bool output_csv(struct whisper_context * ctx, const char * fname) { bool output_csv(struct whisper_context * ctx, struct whisper_state * state, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -349,11 +349,11 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(state, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
@ -365,7 +365,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) { bool output_wts(struct whisper_context * ctx, struct whisper_state * state, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname); std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -378,15 +378,15 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
for (int i = 0; i < whisper_full_n_segments(ctx); i++) { for (int i = 0; i < whisper_full_n_segments(state); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(state, i);
const int n = whisper_full_n_tokens(ctx, i); const int n = whisper_full_n_tokens(state, i);
std::vector<whisper_token_data> tokens(n); std::vector<whisper_token_data> tokens(n);
for (int j = 0; j < n; ++j) { for (int j = 0; j < n; ++j) {
tokens[j] = whisper_full_get_token_data(ctx, i, j); tokens[j] = whisper_full_get_token_data(state, i, j);
} }
if (i > 0) { if (i > 0) {
@ -502,6 +502,9 @@ int main(int argc, char ** argv) {
return 3; return 3;
} }
// state init for the current transformation
struct whisper_state* state = whisper_init_state(ctx);
// initial prompt // initial prompt
std::vector<whisper_token> prompt_tokens; std::vector<whisper_token> prompt_tokens;
@ -682,14 +685,14 @@ int main(int argc, char ** argv) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void* user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };
wparams.encoder_begin_callback_user_data = &is_aborted; wparams.encoder_begin_callback_user_data = &is_aborted;
} }
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { if (whisper_full_parallel(ctx, state, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10; return 10;
} }
@ -702,37 +705,38 @@ int main(int argc, char ** argv) {
// output to text file // output to text file
if (params.output_txt) { if (params.output_txt) {
const auto fname_txt = fname_outp + ".txt"; const auto fname_txt = fname_outp + ".txt";
output_txt(ctx, fname_txt.c_str()); output_txt(ctx, state, fname_txt.c_str());
} }
// output to VTT file // output to VTT file
if (params.output_vtt) { if (params.output_vtt) {
const auto fname_vtt = fname_outp + ".vtt"; const auto fname_vtt = fname_outp + ".vtt";
output_vtt(ctx, fname_vtt.c_str()); output_vtt(ctx, state, fname_vtt.c_str());
} }
// output to SRT file // output to SRT file
if (params.output_srt) { if (params.output_srt) {
const auto fname_srt = fname_outp + ".srt"; const auto fname_srt = fname_outp + ".srt";
output_srt(ctx, fname_srt.c_str(), params); output_srt(ctx, state, fname_srt.c_str(), params);
} }
// output to WTS file // output to WTS file
if (params.output_wts) { if (params.output_wts) {
const auto fname_wts = fname_outp + ".wts"; const auto fname_wts = fname_outp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); output_wts(ctx, state, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
// output to CSV file // output to CSV file
if (params.output_csv) { if (params.output_csv) {
const auto fname_csv = fname_outp + ".csv"; const auto fname_csv = fname_outp + ".csv";
output_csv(ctx, fname_csv.c_str()); output_csv(ctx, state, fname_csv.c_str());
} }
} }
} }
whisper_print_timings(ctx); whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx); whisper_free(ctx);
return 0; return 0;

@ -567,7 +567,7 @@ struct whisper_state {
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// memory buffers used by encode / decode contexts // memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute; std::vector<uint8_t> buf_compute{};
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
int buf_last = 0; int buf_last = 0;
@ -576,11 +576,11 @@ struct whisper_state {
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
std::vector<whisper_segment> result_all; std::vector<whisper_segment> result_all{};
std::vector<whisper_token> prompt_past; std::vector<whisper_token> prompt_past{};
// work container used to avoid memory allocations // work container used to avoid memory allocations
std::vector<std::pair<double, whisper_vocab::id>> logits_id; std::vector<std::pair<double, whisper_vocab::id>> logits_id{};
mutable std::mt19937 rng; // used for sampling at t > 0.0 mutable std::mt19937 rng; // used for sampling at t > 0.0
@ -590,7 +590,7 @@ struct whisper_state {
int64_t t_beg; int64_t t_beg;
int64_t t_last; int64_t t_last;
whisper_token tid_last; whisper_token tid_last;
std::vector<float> energy; // PCM signal energy std::vector<float> energy{}; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default int32_t exp_n_audio_ctx; // 0 - use default
@ -4248,6 +4248,7 @@ int whisper_full(
int whisper_full_parallel( int whisper_full_parallel(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * result_state,
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples, int n_samples,
@ -4262,8 +4263,6 @@ int whisper_full_parallel(
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
const int64_t offset_t = (int64_t)params.offset_ms / 10.0; const int64_t offset_t = (int64_t)params.offset_ms / 10.0;
auto aggregated_whisper_state = whisper_init_state(ctx);
std::vector<whisper_state*> states{}; std::vector<whisper_state*> states{};
for (int i = 0; i < n_processors; i++) for (int i = 0; i < n_processors; i++)
@ -4316,32 +4315,32 @@ int whisper_full_parallel(
// make sure that segments are not overlapping // make sure that segments are not overlapping
if (!aggregated_whisper_state->result_all.empty()) { if (!result_state->result_all.empty()) {
result.t0 = std::max(result.t0, aggregated_whisper_state->result_all.back().t1); result.t0 = std::max(result.t0, result_state->result_all.back().t1);
} }
aggregated_whisper_state->result_all.push_back(std::move(result)); result_state->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment // call the new_segment_callback for each segment
if (params.new_segment_callback) { if (params.new_segment_callback) {
params.new_segment_callback(ctx, aggregated_whisper_state, 1, params.new_segment_callback_user_data); params.new_segment_callback(ctx, result_state, 1, params.new_segment_callback_user_data);
} }
} }
aggregated_whisper_state->t_mel_us += states[i]->t_mel_us; result_state->t_mel_us += states[i]->t_mel_us;
aggregated_whisper_state->t_sample_us += states[i]->t_sample_us; result_state->t_sample_us += states[i]->t_sample_us;
aggregated_whisper_state->t_encode_us += states[i]->t_encode_us; result_state->t_encode_us += states[i]->t_encode_us;
aggregated_whisper_state->t_decode_us += states[i]->t_decode_us; result_state->t_decode_us += states[i]->t_decode_us;
whisper_free_state(states[i]); whisper_free_state(states[i]);
} }
// average the timings // average the timings
aggregated_whisper_state->t_mel_us /= n_processors; result_state->t_mel_us /= n_processors;
aggregated_whisper_state->t_sample_us /= n_processors; result_state->t_sample_us /= n_processors;
aggregated_whisper_state->t_encode_us /= n_processors; result_state->t_encode_us /= n_processors;
aggregated_whisper_state->t_decode_us /= n_processors; result_state->t_decode_us /= n_processors;
// print information about the audio boundaries // print information about the audio boundaries
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -4351,7 +4350,7 @@ int whisper_full_parallel(
} }
fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
whisper_free_state(aggregated_whisper_state); whisper_free_state(result_state);
return ret; return ret;
} }

@ -344,6 +344,7 @@ extern "C" {
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel( WHISPER_API int whisper_full_parallel(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples, int n_samples,

Loading…
Cancel
Save