Fixed rest of the samples

pull/494/head
Sandro Hanea 2 years ago
parent b6aa7becb6
commit d4ee0266ef

@ -413,7 +413,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true; return true;
} }
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f; prob = 0.0f;
@ -435,22 +435,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";
} }
int prob_n = 0; int prob_n = 0;
std::string result; std::string result;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
result += text; result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i); const int n_tokens = whisper_full_n_tokens(state, i);
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j); const auto token = whisper_full_get_token_data(state, i, j);
prob += token.p; prob += token.p;
++prob_n; ++prob_n;
@ -545,7 +545,7 @@ bool process_sdl_events() {
// command-list mode // command-list mode
// guide the transcription to match the most likely command from a provided list // guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { int process_command_list(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__); fprintf(stderr, "%s: guided mode\n", __func__);
@ -666,7 +666,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
wparams.prompt_n_tokens = k_tokens.size(); wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass // run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { if (whisper_full_with_state(ctx, state, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break; break;
} }
@ -674,7 +674,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// estimate command probability // estimate command probability
// NOTE: not optimal // NOTE: not optimal
{ {
const auto * logits = whisper_get_logits(ctx); const auto * logits = whisper_get_logits(state);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f); std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
@ -757,7 +757,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// always-prompt mode // always-prompt mode
// transcribe the voice into text after valid prompt // transcribe the voice into text after valid prompt
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { int always_prompt_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async & audio, const whisper_params & params) {
bool is_running = true; bool is_running = true;
bool ask_prompt = true; bool ask_prompt = true;
@ -799,7 +799,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// detect the commands // detect the commands
audio.get(params.command_ms, pcmf32_cur); audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms));
const auto words = get_words(txt); const auto words = get_words(txt);
@ -835,7 +835,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// general-purpose mode // general-purpose mode
// freely transcribe the voice into text // freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { int process_general_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const whisper_params &params) {
bool is_running = true; bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
@ -879,7 +879,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// wait for activation phrase // wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur); audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
@ -905,7 +905,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// prepend the prompt audio // prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0); prob = 100.0f*(prob - prob0);
@ -958,6 +958,8 @@ int main(int argc, char ** argv) {
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
// print some info about the processing // print some info about the processing
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -995,16 +997,17 @@ int main(int argc, char ** argv) {
int ret_val = 0; int ret_val = 0;
if (!params.commands.empty()) { if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params); ret_val = process_command_list(ctx, state, audio, params);
} else if (!params.prompt.empty()) { } else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params); ret_val = always_prompt_transcription(ctx, state, audio, params);
} else { } else {
ret_val = process_general_transcription(ctx, audio, params); ret_val = process_general_transcription(ctx, state, audio, params);
} }
audio.pause(); audio.pause();
whisper_print_timings(ctx); whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx); whisper_free(ctx);
return ret_val; return ret_val;

@ -503,7 +503,7 @@ int main(int argc, char ** argv) {
} }
// state init for the current transformation // state init for the current transformation
struct whisper_state* state = whisper_init_state(ctx); struct whisper_state * state = whisper_init_state(ctx);
// initial prompt // initial prompt
std::vector<whisper_token> prompt_tokens; std::vector<whisper_token> prompt_tokens;

@ -458,6 +458,7 @@ int main(int argc, char ** argv) {
} }
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
std::vector<float> pcmf32 (n_samples_30s, 0.0f); std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old; std::vector<float> pcmf32_old;
@ -622,7 +623,7 @@ int main(int argc, char ** argv) {
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6; return 6;
} }
@ -645,9 +646,9 @@ int main(int argc, char ** argv) {
printf("\n"); printf("\n");
} }
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
if (params.no_timestamps) { if (params.no_timestamps) {
printf("%s", text); printf("%s", text);
@ -657,8 +658,8 @@ int main(int argc, char ** argv) {
fout << text; fout << text;
} }
} else { } else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(state, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
@ -690,11 +691,11 @@ int main(int argc, char ** argv) {
if (!params.no_context) { if (!params.no_context) {
prompt_tokens.clear(); prompt_tokens.clear();
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i); const int token_count = whisper_full_n_tokens(state, i);
for (int j = 0; j < token_count; ++j) { for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); prompt_tokens.push_back(whisper_full_get_token_id(state, i, j));
} }
} }
} }
@ -704,7 +705,8 @@ int main(int argc, char ** argv) {
audio.pause(); audio.pause();
whisper_print_timings(ctx); whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx); whisper_free(ctx);
return 0; return 0;

@ -16,6 +16,7 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <regex> #include <regex>
#include <cstdlib>
// command-line parameters // command-line parameters
struct whisper_params { struct whisper_params {
@ -419,7 +420,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true; return true;
} }
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f; prob = 0.0f;
@ -441,22 +442,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";
} }
int prob_n = 0; int prob_n = 0;
std::string result; std::string result;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(state, i);
result += text; result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i); const int n_tokens = whisper_full_n_tokens(state, i);
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j); const auto token = whisper_full_get_token_data(state, i, j);
prob += token.p; prob += token.p;
++prob_n; ++prob_n;
@ -470,6 +471,8 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count(); t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
whisper_free_state(state);
return result; return result;
} }
@ -500,6 +503,8 @@ int main(int argc, char ** argv) {
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str()); struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
struct whisper_state * state_wsp = whisper_init_state(ctx_wsp);
// gpt init // gpt init
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str()); struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
@ -591,7 +596,7 @@ int main(int argc, char ** argv) {
std::string text_heard; std::string text_heard;
if (!force_speak) { if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms)); text_heard = ::trim(::transcribe(ctx_wsp, state_wsp, params, pcmf32_cur, prob0, t_ms));
} }
// remove text between brackets using regex // remove text between brackets using regex
@ -688,7 +693,8 @@ int main(int argc, char ** argv) {
audio.pause(); audio.pause();
whisper_print_timings(ctx_wsp); whisper_print_timings(ctx_wsp, state_wsp);
whisper_free_state(state_wsp);
whisper_free(ctx_wsp); whisper_free(ctx_wsp);
return 0; return 0;

@ -2481,7 +2481,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return false; return NULL;
} }
{ {
@ -2491,7 +2491,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false; return NULL;
} }
{ {
@ -3514,7 +3514,7 @@ static void whisper_sequence_score(
} }
} }
int whisper_full_state( int whisper_full_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
struct whisper_full_params params, struct whisper_full_params params,
@ -3523,6 +3523,11 @@ int whisper_full_state(
auto& result_all = state->result_all; auto& result_all = state->result_all;
// clear previous results in case some state is re-used
if (!result_all.empty()) {
result_all.clear();
}
// compute log mel spectrogram // compute log mel spectrogram
if (params.speed_up) { if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder(ctx, state, samples, n_samples, params.n_threads) != 0) { if (whisper_pcm_to_mel_phase_vocoder(ctx, state, samples, n_samples, params.n_threads) != 0) {
@ -4238,7 +4243,7 @@ int whisper_full(
// Initialize a new state for the current transformation // Initialize a new state for the current transformation
whisper_state* state = whisper_init_state(ctx); whisper_state* state = whisper_init_state(ctx);
int result = whisper_full_state(ctx, state, params, samples, n_samples); int result = whisper_full_with_state(ctx, state, params, samples, n_samples);
whisper_free_state(state); whisper_free_state(state);
@ -4287,7 +4292,7 @@ int whisper_full_parallel(
params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr; params_cur.new_segment_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur); workers[i] = std::thread(whisper_full_with_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur);
} }
{ {
@ -4297,7 +4302,7 @@ int whisper_full_parallel(
params_cur.print_realtime = false; params_cur.print_realtime = false;
// Run the first transformation using the first state and for the first chunk. // Run the first transformation using the first state and for the first chunk.
ret = whisper_full_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor); ret = whisper_full_with_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor);
} }
for (int i = 0; i < n_processors - 1; ++i) { for (int i = 0; i < n_processors - 1; ++i) {

@ -339,6 +339,14 @@ extern "C" {
const float * samples, const float * samples,
int n_samples); int n_samples);
// Same as whisper_full() but uses the specified state, so that the caller can read the result on the state after the method call
WHISPER_API int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full() // Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.

Loading…
Cancel
Save