Fixed rest of the samples

pull/494/head
Sandro Hanea 2 years ago
parent b6aa7becb6
commit d4ee0266ef

@ -413,7 +413,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
@ -435,22 +435,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
const int n_tokens = whisper_full_n_tokens(state, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
const auto token = whisper_full_get_token_data(state, i, j);
prob += token.p;
++prob_n;
@ -545,7 +545,7 @@ bool process_sdl_events() {
// command-list mode
// guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
int process_command_list(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
@ -666,7 +666,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
if (whisper_full_with_state(ctx, state, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
@ -674,7 +674,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);
const auto * logits = whisper_get_logits(state);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
@ -757,7 +757,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// always-prompt mode
// transcribe the voice into text after valid prompt
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
int always_prompt_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool ask_prompt = true;
@ -799,7 +799,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms));
const auto words = get_words(txt);
@ -835,7 +835,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
int process_general_transcription(struct whisper_context * ctx, struct whisper_state * state, audio_async &audio, const whisper_params &params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
@ -879,7 +879,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
@ -905,7 +905,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, state, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
@ -958,6 +958,8 @@ int main(int argc, char ** argv) {
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
// print some info about the processing
{
fprintf(stderr, "\n");
@ -995,16 +997,17 @@ int main(int argc, char ** argv) {
int ret_val = 0;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
ret_val = process_command_list(ctx, state, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
ret_val = always_prompt_transcription(ctx, state, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
ret_val = process_general_transcription(ctx, state, audio, params);
}
audio.pause();
whisper_print_timings(ctx);
whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx);
return ret_val;

@ -458,6 +458,7 @@ int main(int argc, char ** argv) {
}
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_state * state = whisper_init_state(ctx);
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
@ -622,7 +623,7 @@ int main(int argc, char ** argv) {
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6;
}
@ -645,9 +646,9 @@ int main(int argc, char ** argv) {
printf("\n");
}
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
if (params.no_timestamps) {
printf("%s", text);
@ -657,8 +658,8 @@ int main(int argc, char ** argv) {
fout << text;
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(state, i);
const int64_t t1 = whisper_full_get_segment_t1(state, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
@ -690,11 +691,11 @@ int main(int argc, char ** argv) {
if (!params.no_context) {
prompt_tokens.clear();
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i);
const int token_count = whisper_full_n_tokens(state, i);
for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
prompt_tokens.push_back(whisper_full_get_token_id(state, i, j));
}
}
}
@ -704,7 +705,8 @@ int main(int argc, char ** argv) {
audio.pause();
whisper_print_timings(ctx);
whisper_print_timings(ctx, state);
whisper_free_state(state);
whisper_free(ctx);
return 0;

@ -16,6 +16,7 @@
#include <thread>
#include <vector>
#include <regex>
#include <cstdlib>
// command-line parameters
struct whisper_params {
@ -419,7 +420,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
std::string transcribe(whisper_context * ctx, whisper_state * state, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
@ -441,22 +442,22 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
if (whisper_full_with_state(ctx, state, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
const int n_segments = whisper_full_n_segments(state);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_segment_text(state, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
const int n_tokens = whisper_full_n_tokens(state, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
const auto token = whisper_full_get_token_data(state, i, j);
prob += token.p;
++prob_n;
@ -470,6 +471,8 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
whisper_free_state(state);
return result;
}
@ -500,6 +503,8 @@ int main(int argc, char ** argv) {
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
struct whisper_state * state_wsp = whisper_init_state(ctx_wsp);
// gpt init
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
@ -591,7 +596,7 @@ int main(int argc, char ** argv) {
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
text_heard = ::trim(::transcribe(ctx_wsp, state_wsp, params, pcmf32_cur, prob0, t_ms));
}
// remove text between brackets using regex
@ -688,7 +693,8 @@ int main(int argc, char ** argv) {
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_print_timings(ctx_wsp, state_wsp);
whisper_free_state(state_wsp);
whisper_free(ctx_wsp);
return 0;

@ -2481,7 +2481,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return false;
return NULL;
}
{
@ -2491,7 +2491,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false;
return NULL;
}
{
@ -3514,7 +3514,7 @@ static void whisper_sequence_score(
}
}
int whisper_full_state(
int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
@ -3523,6 +3523,11 @@ int whisper_full_state(
auto& result_all = state->result_all;
// clear previous results in case some state is re-used
if (!result_all.empty()) {
result_all.clear();
}
// compute log mel spectrogram
if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder(ctx, state, samples, n_samples, params.n_threads) != 0) {
@ -4238,7 +4243,7 @@ int whisper_full(
// Initialize a new state for the current transformation
whisper_state* state = whisper_init_state(ctx);
int result = whisper_full_state(ctx, state, params, samples, n_samples);
int result = whisper_full_with_state(ctx, state, params, samples, n_samples);
whisper_free_state(state);
@ -4287,7 +4292,7 @@ int whisper_full_parallel(
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur);
workers[i] = std::thread(whisper_full_with_state, ctx, states[i + 1], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
@ -4297,7 +4302,7 @@ int whisper_full_parallel(
params_cur.print_realtime = false;
// Run the first transformation using the first state and for the first chunk.
ret = whisper_full_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor);
ret = whisper_full_with_state(ctx, states[0], std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {

@ -339,6 +339,14 @@ extern "C" {
const float * samples,
int n_samples);
// Same as whisper_full(), but runs on the caller-supplied state instead of a temporary one,
// so the caller can read the results from that state after the call returns
// (e.g. whisper_full_n_segments(state), whisper_full_get_segment_text(state, i)).
// Results left in the state by a previous run are cleared before processing starts.
// The state is created with whisper_init_state(ctx) and released with whisper_free_state(state).
// samples / n_samples: the input audio buffer and its element count.
// Callers treat a non-zero return value as failure.
WHISPER_API int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples);
// Split the input audio into chunks and process each chunk separately, running whisper_full_with_state() with a separate state per chunk
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.

Loading…
Cancel
Save