|
|
|
@ -109,6 +109,73 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
|
|
|
fprintf(stderr, "\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct whisper_logits_filter_user_data {
|
|
|
|
|
std::vector<std::string> * allowed_commands;
|
|
|
|
|
std::vector<std::vector<whisper_token>> * allowed_tokens;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void whisper_logits_filter(
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
const whisper_token_data * tokens,
|
|
|
|
|
int n_tokens,
|
|
|
|
|
float * logits,
|
|
|
|
|
void * user_data){
|
|
|
|
|
const auto & allowed_tokens = *((whisper_logits_filter_user_data *) user_data)->allowed_tokens;
|
|
|
|
|
|
|
|
|
|
printf("n_tokens = %d\n", n_tokens);
|
|
|
|
|
for (int i = 0; i < n_tokens; i++) {
|
|
|
|
|
printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, tokens[i].id), logits[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (n_tokens == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<whisper_token, float>> pool;
|
|
|
|
|
for (int i = 0; i < (int) allowed_tokens.size(); i++) {
|
|
|
|
|
const int n = (int) allowed_tokens[i].size();
|
|
|
|
|
if (n_tokens > n) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const whisper_token id = allowed_tokens[i][n_tokens - 1];
|
|
|
|
|
pool.push_back({ id, logits[id] });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (pool.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
printf("applying logits filter, pool size = %d\n", (int) pool.size());
|
|
|
|
|
|
|
|
|
|
const int ibeg = whisper_token_beg(ctx);
|
|
|
|
|
|
|
|
|
|
double sum_all = 0.0;
|
|
|
|
|
for (int i = 0; i < ibeg; ++i) {
|
|
|
|
|
if (logits[i] == -INFINITY) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
sum_all += logits[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double sum_pool = 0.0;
|
|
|
|
|
for (int i = 0; i < (int) pool.size(); ++i) {
|
|
|
|
|
sum_pool += pool[i].second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
printf("sum_all = %.2f, sum_pool = %.2f\n", sum_all, sum_pool);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ibeg; ++i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) pool.size(); ++i) {
|
|
|
|
|
//logits[pool[i].first] = pool[i].second / sum_pool * sum_all;
|
|
|
|
|
logits[pool[i].first] = pool[i].second;
|
|
|
|
|
printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, pool[i].first), logits[pool[i].first]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
|
|
|
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
|
|
|
|
|
|
|
@ -131,6 +198,8 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
|
|
|
|
wparams.audio_ctx = params.audio_ctx;
|
|
|
|
|
wparams.speed_up = params.speed_up;
|
|
|
|
|
|
|
|
|
|
wparams.temperature_inc = -1.0f;
|
|
|
|
|
|
|
|
|
|
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
@ -334,22 +403,31 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
|
|
|
|
wparams.translate = params.translate;
|
|
|
|
|
wparams.no_context = true;
|
|
|
|
|
wparams.single_segment = true;
|
|
|
|
|
wparams.max_tokens = 1;
|
|
|
|
|
//wparams.max_tokens = 1;
|
|
|
|
|
wparams.language = params.language.c_str();
|
|
|
|
|
wparams.n_threads = params.n_threads;
|
|
|
|
|
|
|
|
|
|
wparams.audio_ctx = params.audio_ctx;
|
|
|
|
|
wparams.speed_up = params.speed_up;
|
|
|
|
|
|
|
|
|
|
wparams.temperature_inc = -1.0f;
|
|
|
|
|
|
|
|
|
|
wparams.prompt_tokens = k_tokens.data();
|
|
|
|
|
wparams.prompt_n_tokens = k_tokens.size();
|
|
|
|
|
|
|
|
|
|
whisper_logits_filter_user_data user_data = { &allowed_commands, &allowed_tokens };
|
|
|
|
|
|
|
|
|
|
wparams.logits_filter_callback = whisper_logits_filter;
|
|
|
|
|
wparams.logits_filter_callback_user_data = &user_data;
|
|
|
|
|
|
|
|
|
|
// run the transformer and a single decoding pass
|
|
|
|
|
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stdout, "%s: text - '%s'\n", __func__, whisper_full_get_segment_text(ctx, 0));
|
|
|
|
|
|
|
|
|
|
// estimate command probability
|
|
|
|
|
// NOTE: not optimal
|
|
|
|
|
{
|
|
|
|
@ -436,7 +514,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
|
|
|
|
|
|
|
|
|
// always-prompt mode
|
|
|
|
|
// transcribe the voice into text after valid prompt
|
|
|
|
|
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
|
|
|
|
int process_always_prompt(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
|
|
|
|
bool is_running = true;
|
|
|
|
|
bool ask_prompt = true;
|
|
|
|
|
|
|
|
|
@ -496,7 +574,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|
|
|
|
const float sim = similarity(prompt, k_prompt);
|
|
|
|
|
|
|
|
|
|
//debug
|
|
|
|
|
//fprintf(stdout, "command size: %i\n", command_length);
|
|
|
|
|
//fprintf(stdout, "command size: %d, sim: %f\n", (int) command.size(), sim);
|
|
|
|
|
|
|
|
|
|
if ((sim > 0.7f) && (command.size() > 0)) {
|
|
|
|
|
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
|
|
|
@ -676,7 +754,7 @@ int main(int argc, char ** argv) {
|
|
|
|
|
if (!params.commands.empty()) {
|
|
|
|
|
ret_val = process_command_list(ctx, audio, params);
|
|
|
|
|
} else if (!params.prompt.empty()) {
|
|
|
|
|
ret_val = always_prompt_transcription(ctx, audio, params);
|
|
|
|
|
ret_val = process_always_prompt(ctx, audio, params);
|
|
|
|
|
} else {
|
|
|
|
|
ret_val = process_general_transcription(ctx, audio, params);
|
|
|
|
|
}
|
|
|
|
|