diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 48e0292..d149e49 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -235,7 +235,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const char * text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } diff --git a/whisper.cpp b/whisper.cpp index dad1139..45181e3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -425,6 +425,7 @@ struct whisper_sequence { double sum_logprobs; // the sum of the log probabilities of the tokens double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens double score; // likelihood rank score }; @@ -2700,10 +2701,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_timestamp =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_increment =*/ 0.2f, - /*.compression_ratio_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + /*.temperature_increment =*/ 0.2f, + /*.entropy_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, /*.greedy =*/ { /*.best_of =*/ 5, @@ -2760,10 +2761,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_timestamp =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_increment =*/ 0.2f, - /*.compression_ratio_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + /*.temperature_increment =*/ 0.2f, + /*.entropy_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, 
/*.greedy =*/ { /*.best_of =*/ 5, @@ -3081,6 +3082,11 @@ static whisper_token_data whisper_sample_token( result.plog = logprobs[result.id]; } + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + return result; } @@ -3106,6 +3112,28 @@ static void whisper_sequence_score( } sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + //printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } } int whisper_full( @@ -3322,9 +3350,10 @@ int whisper_full( decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; - decoder.sequence.sum_logprobs = 0.0; - decoder.sequence.avg_logprobs = 0.0; - decoder.sequence.score = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0f; + decoder.sequence.score = -INFINITY; decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; @@ -3543,12 +3572,22 @@ int whisper_full( decoder.sequence.tokens.resize(decoder.sequence.result_len); whisper_sequence_score(params, decoder.sequence); + fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + + if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) { + fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_threshold); + + decoder.failed = true; + + 
continue; + } + if (best_score < decoder.sequence.score) { best_score = decoder.sequence.score; best_decoder_id = j; } - - //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs); } } @@ -3559,13 +3598,13 @@ int whisper_full( { auto & decoder = ctx->decoders[best_decoder_id]; - if (decoder.sequence.avg_logprobs < params.logprob_threshold) { + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) { success = false; } } if (success) { - //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); + fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { // fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); //} @@ -3623,6 +3662,8 @@ int whisper_full( } } + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + result_all.push_back({ tt0, tt1, text, {} }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); @@ -3690,6 +3731,8 @@ int whisper_full( // update audio window seek += seek_delta; + + fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta); } } diff --git a/whisper.h b/whisper.h index 8c94a7a..69a2246 100644 --- a/whisper.h +++ b/whisper.h @@ -279,7 +279,7 @@ extern "C" { // fallback parameters float temperature_increment; - float compression_ratio_threshold; + float entropy_threshold; // analogous to OpenAI's compression_ratio_threshold float logprob_threshold; float no_speech_threshold;