whisper : minor improvemnt in decoding strategy (#244)

Do not allow for text segments to go beyond end of audio.
This partially mitigates some issues when the last audio window is 1-2
seconds just before the end of the audio file and the decoding spirals
into a repetition of the last transcribed phrase.
pull/245/head
Georgi Gerganov 2 years ago
parent 50a061b313
commit 603f97ba11
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -2687,6 +2687,7 @@ int whisper_full(
tokens_cur.clear(); tokens_cur.clear();
bool failed = false; bool failed = false;
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
@ -2712,13 +2713,13 @@ int whisper_full(
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time // do not allow to go back in time
if (seek_delta != 100*WHISPER_CHUNK_SIZE && if (has_ts && seek_delta > seek_delta_new && result_len < i) {
seek_delta > seek_delta_new && result_len < i) {
break; break;
} }
seek_delta = seek_delta_new; seek_delta = seek_delta_new;
result_len = i + 1; result_len = i + 1;
has_ts = true;
} }
// add it to the context // add it to the context
@ -2730,8 +2731,11 @@ int whisper_full(
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
//} //}
// end of text token // end of segment
if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) { if (token.id == whisper_token_eot(ctx) || // end of text token
(params.max_tokens > 0 && i > params.max_tokens) || // max tokens per segment reached
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
) {
if (result_len == 0) { if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) { if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1; result_len = i + 1;

Loading…
Cancel
Save