diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 3dae3a5..2bdaf87 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const break; } - const auto * probs = whisper_get_probs(ctx); - std::vector> probs_id; - - double psum = 0.0; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - probs_id.emplace_back(probs[allowed_tokens[i][0]], i); - for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { - probs_id.back().first += probs[allowed_tokens[i][j]]; - } - probs_id.back().first /= allowed_tokens[i].size(); - psum += probs_id.back().first; - } + // estimate command probability + // NOTE: not optimal + { + const auto * logits = whisper_get_logits(ctx); - // normalize - for (auto & p : probs_id) { - p.first /= psum; - } + std::vector probs(whisper_n_vocab(ctx), 0.0f); - // sort descending - { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); - } + // compute probs from logits via softmax + { + float max = -1e9; + for (int i = 0; i < (int) probs.size(); ++i) { + max = std::max(max, logits[i]); + } - // print the commands and the respective probabilities - { - fprintf(stdout, "\n"); - for (const auto & cmd : probs_id) { - fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); - for (int token : allowed_tokens[cmd.second]) { - fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + float sum = 0.0f; + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] = expf(logits[i] - max); + sum += probs[i]; + } + + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] /= sum; } + } + + std::vector> probs_id; + + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.emplace_back(probs[allowed_tokens[i][0]], i); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; + } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } + + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // print the commands and the respective probabilities + { fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int token : allowed_tokens[cmd.second]) { + fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + } + fprintf(stdout, "\n"); + } } - } - // best command - { - const auto t_end = std::chrono::high_resolution_clock::now(); + // best command + { + const auto t_end = std::chrono::high_resolution_clock::now(); - const float prob = probs_id[0].first; - const int index = probs_id[0].second; + const float prob = probs_id[0].first; + const int index = probs_id[0].second; - fprintf(stdout, "\n"); - fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, - "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, - (int) std::chrono::duration_cast(t_end - t_start).count()); - fprintf(stdout, "\n"); + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, + (int) std::chrono::duration_cast(t_end - t_start).count()); + fprintf(stdout, "\n"); + } } audio.clear();