Fixed comments

pull/523/head
Sandro Hanea 2 years ago
parent 4608a7524e
commit 4d5317101d

@ -567,7 +567,7 @@ struct whisper_state {
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// memory buffers used by encode / decode contexts // memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute{}; std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
int buf_last = 0; int buf_last = 0;
@ -576,11 +576,11 @@ struct whisper_state {
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
std::vector<whisper_segment> result_all{}; std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past{}; std::vector<whisper_token> prompt_past;
// work container used to avoid memory allocations // work container used to avoid memory allocations
std::vector<std::pair<double, whisper_vocab::id>> logits_id{}; std::vector<std::pair<double, whisper_vocab::id>> logits_id;
mutable std::mt19937 rng; // used for sampling at t > 0.0 mutable std::mt19937 rng; // used for sampling at t > 0.0
@ -590,7 +590,7 @@ struct whisper_state {
int64_t t_beg = 0; int64_t t_beg = 0;
int64_t t_last = 0; int64_t t_last = 0;
whisper_token tid_last; whisper_token tid_last;
std::vector<float> energy{}; // PCM signal energy std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default int32_t exp_n_audio_ctx = 0; // 0 - use default
@ -601,8 +601,7 @@ struct whisper_state {
if (i == -1) { if (i == -1) {
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
} } else {
else {
auto & buf = buf_scratch[i]; auto & buf = buf_scratch[i];
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
} }
@ -637,7 +636,7 @@ struct whisper_context {
whisper_model model; whisper_model model;
whisper_vocab vocab; whisper_vocab vocab;
whisper_state* default_state = nullptr; whisper_state * state = nullptr;
}; };
template<typename T> template<typename T>
@ -2475,19 +2474,6 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
return tokens; return tokens;
} }
// Initializes default state in the context
// returns true if new state was initiallized or false if the state already exists
//
bool whisper_init_default_state(struct whisper_context * wctx)
{
if (wctx->default_state == nullptr) {
wctx->default_state = whisper_init_state(wctx);
return true;
}
return false;
}
// //
// interface implementation // interface implementation
// //
@ -2540,7 +2526,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return state; return state;
} }
struct whisper_context * whisper_init_from_file(const char * path_model) { struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
whisper_model_loader loader = {}; whisper_model_loader loader = {};
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@ -2568,10 +2554,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
fin->close(); fin->close();
}; };
return whisper_init(&loader); return whisper_init_no_state(&loader);
} }
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
struct buf_context { struct buf_context {
uint8_t* buffer; uint8_t* buffer;
size_t size; size_t size;
@ -2604,10 +2590,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
loader.close = [](void * /*ctx*/) { }; loader.close = [](void * /*ctx*/) { };
return whisper_init(&loader); return whisper_init_no_state(&loader);
} }
struct whisper_context * whisper_init(struct whisper_model_loader * loader) { struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
ggml_time_init(); ggml_time_init();
whisper_context * ctx = new whisper_context; whisper_context * ctx = new whisper_context;
@ -2624,6 +2610,51 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
return ctx; return ctx;
} }
struct whisper_context * whisper_init_from_file(const char * path_model) {
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
if (!ctx) {
return nullptr;
}
ctx->state = whisper_init_state(ctx);
if (!ctx->state) {
whisper_free(ctx);
return nullptr;
}
return ctx;
}
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
if (!ctx) {
return nullptr;
}
ctx->state = whisper_init_state(ctx);
if (!ctx->state) {
whisper_free(ctx);
return nullptr;
}
return ctx;
}
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
whisper_context * ctx = whisper_init_no_state(loader);
if (!ctx) {
return nullptr;
}
ctx->state = whisper_init_state(ctx);
if (!ctx->state) {
whisper_free(ctx);
return nullptr;
}
return ctx;
}
void whisper_free_state(struct whisper_state * state) void whisper_free_state(struct whisper_state * state)
{ {
if (state) { if (state) {
@ -2648,15 +2679,13 @@ void whisper_free(struct whisper_context * ctx) {
delete ctx->model.buf; delete ctx->model.buf;
} }
whisper_free_state(ctx->default_state); whisper_free_state(ctx->state);
delete ctx; delete ctx;
} }
} }
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
whisper_init_default_state(ctx);
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1; return -1;
@ -2666,9 +2695,7 @@ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_s
} }
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
whisper_init_default_state(ctx); return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
return whisper_pcm_to_mel_with_state(ctx, ctx->default_state, samples, n_samples, n_threads);
} }
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
@ -2683,9 +2710,7 @@ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, st
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
whisper_init_default_state(ctx); return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->default_state, samples, n_samples, n_threads);
} }
int whisper_set_mel_with_state( int whisper_set_mel_with_state(
@ -2713,9 +2738,7 @@ int whisper_set_mel(
const float * data, const float * data,
int n_len, int n_len,
int n_mel) { int n_mel) {
whisper_init_default_state(ctx); return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
return whisper_set_mel_with_state(ctx, ctx->default_state, data, n_len, n_mel);
} }
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
@ -2728,7 +2751,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
} }
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
if (!whisper_encode_internal(*ctx, *ctx->default_state, offset, n_threads)) { if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return -1; return -1;
} }
@ -2751,13 +2774,13 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
// TODO: add selected_decoder_id to state // TODO: add selected_decoder_id to state
const int selected_decoder_id = 0; const int selected_decoder_id = 0;
if (ctx->default_state == nullptr) { if (ctx->state == nullptr) {
fprintf(stderr, "%s: ERROR default_state was not loaded.\n", __func__); fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
return false; return false;
} }
if (!whisper_decode_internal(*ctx, *ctx->default_state, ctx->default_state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return 1; return 1;
} }
@ -2895,8 +2918,7 @@ int whisper_lang_auto_detect(
int offset_ms, int offset_ms,
int n_threads, int n_threads,
float * lang_probs) { float * lang_probs) {
whisper_init_default_state(ctx); return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
return whisper_lang_auto_detect_with_state(ctx, ctx->default_state, offset_ms, n_threads, lang_probs);
} }
int whisper_n_len_from_state(struct whisper_state * state) { int whisper_n_len_from_state(struct whisper_state * state) {
@ -2904,7 +2926,7 @@ int whisper_n_len_from_state(struct whisper_state * state) {
} }
int whisper_n_len(struct whisper_context * ctx) { int whisper_n_len(struct whisper_context * ctx) {
return ctx->default_state->mel.n_len; return ctx->state->mel.n_len;
} }
int whisper_n_vocab(struct whisper_context * ctx) { int whisper_n_vocab(struct whisper_context * ctx) {
@ -2924,7 +2946,7 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
} }
float * whisper_get_logits(struct whisper_context * ctx) { float * whisper_get_logits(struct whisper_context * ctx) {
return ctx->default_state->logits.data(); return ctx->state->logits.data();
} }
@ -2977,27 +2999,26 @@ void whisper_print_timings(struct whisper_context * ctx) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
if (ctx->default_state != nullptr) { if (ctx->state != nullptr) {
const int32_t n_sample = std::max(1, ctx->default_state->n_sample); const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->default_state->n_encode); const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->default_state->n_decode); const int32_t n_decode = std::max(1, ctx->state->n_decode);
fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->default_state->n_fail_p, ctx->default_state->n_fail_h); fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->default_state->t_mel_us / 1000.0f); fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->default_state->t_sample_us, n_sample, 1e-3f * ctx->default_state->t_sample_us / n_sample); fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->default_state->t_encode_us, n_encode, 1e-3f * ctx->default_state->t_encode_us / n_encode); fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->default_state->t_decode_us, n_decode, 1e-3f * ctx->default_state->t_decode_us / n_decode); fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
} }
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
} }
void whisper_reset_timings(struct whisper_context * ctx) { void whisper_reset_timings(struct whisper_context * ctx) {
if (ctx->default_state != nullptr) { if (ctx->state != nullptr) {
ctx->state->t_sample_us = 0;
ctx->default_state->t_sample_us = 0; ctx->state->t_encode_us = 0;
ctx->default_state->t_encode_us = 0; ctx->state->t_decode_us = 0;
ctx->default_state->t_decode_us = 0;
} }
} }
@ -4335,9 +4356,7 @@ int whisper_full(
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples) { int n_samples) {
//This is not thread safe and it's using default_state return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
whisper_init_default_state(ctx);
return whisper_full_with_state(ctx, ctx->default_state, params, samples, n_samples);
} }
int whisper_full_parallel( int whisper_full_parallel(
@ -4349,12 +4368,10 @@ int whisper_full_parallel(
if (n_processors == 1) { if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples); return whisper_full(ctx, params, samples, n_samples);
} }
whisper_init_default_state(ctx);
int ret = 0; int ret = 0;
// prepare separate states for each thread // prepare separate states for each thread
std::vector<whisper_state*> states{}; std::vector<whisper_state*> states;
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
@ -4389,7 +4406,7 @@ int whisper_full_parallel(
params_cur.print_realtime = false; params_cur.print_realtime = false;
// Run the first transformation using default state but only for the first chunk. // Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state(ctx, ctx->default_state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
} }
for (int i = 0; i < n_processors - 1; ++i) { for (int i = 0; i < n_processors - 1; ++i) {
@ -4409,32 +4426,32 @@ int whisper_full_parallel(
// make sure that segments are not overlapping // make sure that segments are not overlapping
if (!ctx->default_state->result_all.empty()) { if (!ctx->state->result_all.empty()) {
result.t0 = std::max(result.t0, ctx->default_state->result_all.back().t1); result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
} }
ctx->default_state->result_all.push_back(std::move(result)); ctx->state->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment // call the new_segment_callback for each segment
if (params.new_segment_callback) { if (params.new_segment_callback) {
params.new_segment_callback(ctx, ctx->default_state, 1, params.new_segment_callback_user_data); params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
} }
} }
ctx->default_state->t_mel_us += states[i]->t_mel_us; ctx->state->t_mel_us += states[i]->t_mel_us;
ctx->default_state->t_sample_us += states[i]->t_sample_us; ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->default_state->t_encode_us += states[i]->t_encode_us; ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->default_state->t_decode_us += states[i]->t_decode_us; ctx->state->t_decode_us += states[i]->t_decode_us;
whisper_free_state(states[i]); whisper_free_state(states[i]);
} }
// average the timings // average the timings
ctx->default_state->t_mel_us /= n_processors; ctx->state->t_mel_us /= n_processors;
ctx->default_state->t_sample_us /= n_processors; ctx->state->t_sample_us /= n_processors;
ctx->default_state->t_encode_us /= n_processors; ctx->state->t_encode_us /= n_processors;
ctx->default_state->t_decode_us /= n_processors; ctx->state->t_decode_us /= n_processors;
// print information about the audio boundaries // print information about the audio boundaries
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -4452,7 +4469,7 @@ int whisper_full_n_segments_from_state(struct whisper_state * state) {
} }
int whisper_full_n_segments(struct whisper_context * ctx) { int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->default_state->result_all.size(); return ctx->state->result_all.size();
} }
int whisper_full_lang_id_from_state(struct whisper_state * state) { int whisper_full_lang_id_from_state(struct whisper_state * state) {
@ -4460,7 +4477,7 @@ int whisper_full_lang_id_from_state(struct whisper_state * state) {
} }
int whisper_full_lang_id(struct whisper_context * ctx) { int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->default_state->lang_id; return ctx->state->lang_id;
} }
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
@ -4468,7 +4485,7 @@ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int
} }
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->default_state->result_all[i_segment].t0; return ctx->state->result_all[i_segment].t0;
} }
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
@ -4476,7 +4493,7 @@ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int
} }
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
return ctx->default_state->result_all[i_segment].t1; return ctx->state->result_all[i_segment].t1;
} }
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
@ -4484,7 +4501,7 @@ const char * whisper_full_get_segment_text_from_state(struct whisper_state * sta
} }
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
return ctx->default_state->result_all[i_segment].text.c_str(); return ctx->state->result_all[i_segment].text.c_str();
} }
int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
@ -4492,7 +4509,7 @@ int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment
} }
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
return ctx->default_state->result_all[i_segment].tokens.size(); return ctx->state->result_all[i_segment].tokens.size();
} }
const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
@ -4500,7 +4517,7 @@ const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx
} }
const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->vocab.id_to_token[ctx->default_state->result_all[i_segment].tokens[i_token].id].c_str(); return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
} }
whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
@ -4508,7 +4525,7 @@ whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state,
} }
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->default_state->result_all[i_segment].tokens[i_token].id; return ctx->state->result_all[i_segment].tokens[i_token].id;
} }
struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
@ -4516,7 +4533,7 @@ struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_
} }
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->default_state->result_all[i_segment].tokens[i_token]; return ctx->state->result_all[i_segment].tokens[i_token];
} }
float whisper_full_get_token_p_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { float whisper_full_get_token_p_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
@ -4524,7 +4541,7 @@ float whisper_full_get_token_p_from_state(struct whisper_context * ctx, struct w
} }
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->default_state->result_all[i_segment].tokens[i_token].p; return ctx->state->result_all[i_segment].tokens[i_token].p;
} }
// ================================================================================================= // =================================================================================================

@ -101,6 +101,9 @@ extern "C" {
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
@ -119,9 +122,6 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided state.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_with_state( WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -138,9 +138,6 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided state.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -158,10 +155,6 @@ extern "C" {
int n_len, int n_len,
int n_mel); int n_mel);
// This can be used to set a custom log mel spectrogram inside the provided state.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel_with_state( WHISPER_API int whisper_set_mel_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -178,10 +171,6 @@ extern "C" {
int offset, int offset,
int n_threads); int n_threads);
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper state.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode_with_state( WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -201,13 +190,6 @@ extern "C" {
int n_past, int n_past,
int n_threads); int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode_with_state( WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -251,12 +233,6 @@ extern "C" {
int n_threads, int n_threads,
float * lang_probs); float * lang_probs);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel_with_state() or whisper_set_mel_with_state() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whispe_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect_with_state( WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -424,9 +400,6 @@ extern "C" {
const float * samples, const float * samples,
int n_samples); int n_samples);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text on the given state
// Thread safe for same context, as long as different state is provided.
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full_with_state( WHISPER_API int whisper_full_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,

Loading…
Cancel
Save