diff --git a/whisper.cpp b/whisper.cpp index 331d408..04cbc36 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.suppress_blank =*/ true, - /*.suppress_non_speech_tokens =*/true, + /*.suppress_non_speech_tokens =*/ false, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -3078,8 +3078,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool return res; } -static const std::vector non_speech_tokens -{ +static const std::vector non_speech_tokens = { "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", @@ -3149,26 +3148,21 @@ static void whisper_process_logits( // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - if (params.suppress_non_speech_tokens) - { - for (const std::string &token : non_speech_tokens) - { - std::string suppress_tokens[] = {token, " " + token}; - for (const std::string &suppress_token : suppress_tokens) - { - if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) - { + if (params.suppress_non_speech_tokens) { + for (const std::string & token : non_speech_tokens) { + const std::string suppress_tokens[] = {token, " " + token}; + for (const std::string & suppress_token : suppress_tokens) { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; } } } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) - { + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(" -")] = -INFINITY; } - if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) - { + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(" '")] = -INFINITY; } }