From cfc06bf8dfea876dded28cff647bd98267dfc718 Mon Sep 17 00:00:00 2001 From: shibukazu <61775791+shibukazu@users.noreply.github.com> Date: Wed, 8 Feb 2023 16:05:34 +0900 Subject: [PATCH] whisper : suppress non-speech-related token outputs (#473) * add non-speech-token suppression * add suppress non-speech_tokens param --- whisper.cpp | 36 ++++++++++++++++++++++++++++++++++++ whisper.h | 1 + 2 files changed, 37 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index aebb481..24e16bd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/true, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool return res; } +static const std::vector non_speech_tokens +{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs @@ -3137,6 +3146,33 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) + { + for (const std::string &token : non_speech_tokens) + { + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 { diff --git a/whisper.h b/whisper.h index 786d67d..7eece79 100644 --- a/whisper.h +++ b/whisper.h @@ -285,6 +285,7 @@ extern "C" { // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97