whisper : support for compression_ratio_threshold

We actually use entropy instead, but it serves a similar purpose
pull/291/head
Georgi Gerganov 3 years ago
parent 34c5110f59
commit c67716faba

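Background on the approach: instead of compressing the decoded text and checking the compression ratio, as the OpenAI reference implementation does, this change computes the Shannon entropy of the last 32 decoded token ids and marks a decoder as failed when that entropy drops below a threshold (default 2.4), since highly repetitive (looping) output has low entropy. The snippet below is a standalone sketch of the idea, not code from the commit; the helper name token_entropy and the sample token ids are made up for illustration.

    // Standalone sketch (see assumptions above): Shannon entropy over a
    // window of recent token ids as a cheap repetition detector.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <map>
    #include <vector>

    // Entropy of the last `window` token ids, mirroring the window of 32
    // used in whisper_sequence_score in the diff below.
    static double token_entropy(const std::vector<int> & tokens, int window = 32) {
        const int n  = (int) tokens.size();
        const int i0 = std::max(0, n - window);

        std::map<int, int> counts;
        for (int i = i0; i < n; ++i) {
            counts[tokens[i]]++;
        }

        const double cnt = double(n - i0);
        double entropy = 0.0;
        for (const auto & kv : counts) {
            const double p = kv.second/cnt;
            entropy -= p*std::log(p);
        }
        return entropy;
    }

    int main() {
        // repetitive ("looping") output -> low entropy, below the 2.4 default
        const std::vector<int> looped = { 50, 51, 50, 51, 50, 51, 50, 51, 50, 51, 50, 51 };
        // varied output -> higher entropy, above the threshold
        const std::vector<int> varied = { 12, 87, 3, 44, 91, 7, 65, 23, 58, 19, 71, 36 };

        printf("looped entropy: %.3f\n", token_entropy(looped)); // ~0.693 (ln 2)
        printf("varied entropy: %.3f\n", token_entropy(varied)); // ~2.485 (ln 12)

        return 0;
    }

As in the change below, the check is only applied once a sequence has more than 8 decoded tokens, so very short segments are not rejected by accident.
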
@@ -235,7 +235,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+ const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}

@@ -425,6 +425,7 @@ struct whisper_sequence {
double sum_logprobs; // the sum of the log probabilities of the tokens
double avg_logprobs; // the average log probability of the tokens
+ double entropy; // the entropy of the tokens
double score; // likelihood rank score
};
@@ -2700,10 +2701,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
- /*.temperature_increment =*/ 0.2f,
- /*.compression_ratio_threshold =*/ 2.4f,
- /*.logprob_threshold =*/ -1.0f,
- /*.no_speech_threshold =*/ 0.6f,
+ /*.temperature_increment =*/ 0.2f,
+ /*.entropy_threshold =*/ 2.4f,
+ /*.logprob_threshold =*/ -1.0f,
+ /*.no_speech_threshold =*/ 0.6f,
/*.greedy =*/ {
/*.best_of =*/ 5,
@@ -2760,10 +2761,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
- /*.temperature_increment =*/ 0.2f,
- /*.compression_ratio_threshold =*/ 2.4f,
- /*.logprob_threshold =*/ -1.0f,
- /*.no_speech_threshold =*/ 0.6f,
+ /*.temperature_increment =*/ 0.2f,
+ /*.entropy_threshold =*/ 2.4f,
+ /*.logprob_threshold =*/ -1.0f,
+ /*.no_speech_threshold =*/ 0.6f,
/*.greedy =*/ {
/*.best_of =*/ 5,
@@ -3081,6 +3082,11 @@ static whisper_token_data whisper_sample_token(
result.plog = logprobs[result.id];
}
+ if (result.id >= vocab.token_beg) {
+ result.tid = result.id;
+ result.pt = result.p;
+ }
return result;
}
@@ -3106,6 +3112,28 @@ static void whisper_sequence_score(
}
sequence.score = result/penalty;
+ // compute the entropy of the sequence of the last 32 tokens
+ {
+ const int n = 32;
+ int cnt = 0;
+ double entropy = 0.0f;
+ std::map<whisper_token, int> token_counts;
+ for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
+ token_counts[sequence.tokens[i].id]++;
+ cnt++;
+ }
+ for (const auto & kv : token_counts) {
+ const auto p = kv.second/(double)cnt;
+ entropy -= p*log(p);
+ //printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+ }
+ sequence.entropy = entropy;
+ }
}
int whisper_full(
@@ -3322,9 +3350,10 @@ int whisper_full(
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
- decoder.sequence.sum_logprobs = 0.0;
- decoder.sequence.avg_logprobs = 0.0;
- decoder.sequence.score = 0.0;
+ decoder.sequence.sum_logprobs = -INFINITY;
+ decoder.sequence.avg_logprobs = -INFINITY;
+ decoder.sequence.entropy = 0.0f;
+ decoder.sequence.score = -INFINITY;
decoder.n_past = 0;
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
@@ -3543,12 +3572,22 @@ int whisper_full(
decoder.sequence.tokens.resize(decoder.sequence.result_len);
whisper_sequence_score(params, decoder.sequence);
+ fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+ __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
+ if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
+ fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+ __func__, j, decoder.sequence.entropy, params.entropy_threshold);
+ decoder.failed = true;
+ continue;
+ }
if (best_score < decoder.sequence.score) {
best_score = decoder.sequence.score;
best_decoder_id = j;
}
- //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
}
}
@@ -3559,13 +3598,13 @@ int whisper_full(
{
auto & decoder = ctx->decoders[best_decoder_id];
- if (decoder.sequence.avg_logprobs < params.logprob_threshold) {
+ if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) {
success = false;
}
}
if (success) {
- //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
+ fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
@@ -3623,6 +3662,8 @@ int whisper_full(
}
}
+ //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
@@ -3690,6 +3731,8 @@ int whisper_full(
// update audio window
seek += seek_delta;
+ fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta);
}
}

@@ -279,7 +279,7 @@ extern "C" {
// fallback parameters
float temperature_increment;
- float compression_ratio_threshold;
+ float entropy_threshold; // analog to OpenAI's compression_ratio_threshold
float logprob_threshold;
float no_speech_threshold;
