add suppress non-speech_tokens param

pull/473/head
shibukazu 2 years ago
parent 2d3332e26f
commit a8f0bd4e89

@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en", /*.language =*/ "en",
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/true,
/*.temperature =*/ 0.0f, /*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
@ -3073,6 +3074,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
return res; return res;
} }
static const std::vector<std::string> non_speech_tokens
{
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","", "", "", "", "", "", ""
};
// process the logits for the selected decoder // process the logits for the selected decoder
// - applies logit filters // - applies logit filters
// - computes logprobs and probs // - computes logprobs and probs
@ -3136,13 +3145,8 @@ static void whisper_process_logits(
// suppress non-speech tokens // suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
std::vector<std::string> non_speech_tokens{ if (params.suppress_non_speech_tokens)
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", {
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","", "", "", "", "", "", ""
};
for (const std::string &token : non_speech_tokens) for (const std::string &token : non_speech_tokens)
{ {
std::string suppress_tokens[] = {token, " " + token}; std::string suppress_tokens[] = {token, " " + token};
@ -3155,12 +3159,15 @@ static void whisper_process_logits(
} }
} }
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
{
logits[vocab.token_to_id.at(" -")] = -INFINITY; logits[vocab.token_to_id.at(" -")] = -INFINITY;
} }
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
{
logits[vocab.token_to_id.at(" '")] = -INFINITY; logits[vocab.token_to_id.at(" '")] = -INFINITY;
} }
}
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424

@ -275,6 +275,7 @@ extern "C" {
// common decoding parameters: // common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97

Loading…
Cancel
Save