|
|
|
@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto * probs = whisper_get_probs(ctx);
|
|
|
|
|
std::vector<std::pair<float, int>> probs_id;
|
|
|
|
|
|
|
|
|
|
double psum = 0.0;
|
|
|
|
|
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
|
|
|
|
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
|
|
|
|
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
|
|
|
|
probs_id.back().first += probs[allowed_tokens[i][j]];
|
|
|
|
|
}
|
|
|
|
|
probs_id.back().first /= allowed_tokens[i].size();
|
|
|
|
|
psum += probs_id.back().first;
|
|
|
|
|
}
|
|
|
|
|
// estimate command probability
|
|
|
|
|
// NOTE: not optimal
|
|
|
|
|
{
|
|
|
|
|
const auto * logits = whisper_get_logits(ctx);
|
|
|
|
|
|
|
|
|
|
// normalize
|
|
|
|
|
for (auto & p : probs_id) {
|
|
|
|
|
p.first /= psum;
|
|
|
|
|
}
|
|
|
|
|
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
|
|
|
|
|
|
|
|
|
|
// sort descending
|
|
|
|
|
{
|
|
|
|
|
using pair_type = decltype(probs_id)::value_type;
|
|
|
|
|
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
// compute probs from logits via softmax
|
|
|
|
|
{
|
|
|
|
|
float max = -1e9;
|
|
|
|
|
for (int i = 0; i < (int) probs.size(); ++i) {
|
|
|
|
|
max = std::max(max, logits[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// print the commands and the respective probabilities
|
|
|
|
|
{
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
for (const auto & cmd : probs_id) {
|
|
|
|
|
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
|
|
|
|
for (int token : allowed_tokens[cmd.second]) {
|
|
|
|
|
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
|
|
|
|
float sum = 0.0f;
|
|
|
|
|
for (int i = 0; i < (int) probs.size(); ++i) {
|
|
|
|
|
probs[i] = expf(logits[i] - max);
|
|
|
|
|
sum += probs[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) probs.size(); ++i) {
|
|
|
|
|
probs[i] /= sum;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<float, int>> probs_id;
|
|
|
|
|
|
|
|
|
|
double psum = 0.0;
|
|
|
|
|
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
|
|
|
|
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
|
|
|
|
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
|
|
|
|
probs_id.back().first += probs[allowed_tokens[i][j]];
|
|
|
|
|
}
|
|
|
|
|
probs_id.back().first /= allowed_tokens[i].size();
|
|
|
|
|
psum += probs_id.back().first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// normalize
|
|
|
|
|
for (auto & p : probs_id) {
|
|
|
|
|
p.first /= psum;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// sort descending
|
|
|
|
|
{
|
|
|
|
|
using pair_type = decltype(probs_id)::value_type;
|
|
|
|
|
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// print the commands and the respective probabilities
|
|
|
|
|
{
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
for (const auto & cmd : probs_id) {
|
|
|
|
|
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
|
|
|
|
for (int token : allowed_tokens[cmd.second]) {
|
|
|
|
|
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
|
|
|
|
}
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// best command
|
|
|
|
|
{
|
|
|
|
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
|
|
|
// best command
|
|
|
|
|
{
|
|
|
|
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
|
|
|
|
|
|
|
|
const float prob = probs_id[0].first;
|
|
|
|
|
const int index = probs_id[0].second;
|
|
|
|
|
const float prob = probs_id[0].first;
|
|
|
|
|
const int index = probs_id[0].second;
|
|
|
|
|
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
|
|
|
|
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
|
|
|
|
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
|
|
|
|
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
|
|
|
|
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
audio.clear();
|
|
|
|
|