From 8f8a5aca99135befb68ce1fa1c329d5110c9c146 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Feb 2023 21:10:50 +0200 Subject: [PATCH] sync : latest whisper.cpp --- examples/whisper/main.cpp | 4 +-- examples/whisper/whisper.cpp | 48 ++++++++++++++++++------------------ examples/whisper/whisper.h | 14 +++++++++++ 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 5bd7e42..b8366b7 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -91,12 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params); bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; - + if (arg == "-"){ params.fname_inp.push_back(arg); continue; } - + if (arg[0] != '-') { params.fname_inp.push_back(arg); continue; diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 331d408..3a21581 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -592,16 +592,16 @@ struct whisper_context { mutable std::mt19937 rng; // used for sampling at t > 0.0 - int lang_id; + int lang_id = 0; // english by default // [EXPERIMENTAL] token-level timestamps data - int64_t t_beg; - int64_t t_last; + int64_t t_beg = 0; + int64_t t_last = 0; whisper_token tid_last; std::vector energy; // PCM signal energy // [EXPERIMENTAL] speed-up techniques - int32_t exp_n_audio_ctx; // 0 - use default + int32_t exp_n_audio_ctx = 0; // 0 - use default void use_buf(struct ggml_context * ctx, int i) { #if defined(WHISPER_USE_SCRATCH) @@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con MEM_REQ_SCRATCH3.at (model.type) + scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); // this is the memory required by one decoder const size_t mem_required_decoder = @@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.suppress_blank =*/ true, - /*.suppress_non_speech_tokens =*/true, + /*.suppress_non_speech_tokens =*/ false, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, + + /*.logits_filter_callback =*/ nullptr, + /*.logits_filter_callback_user_data =*/ nullptr, }; switch (strategy) { @@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool return res; } -static const std::vector non_speech_tokens -{ +static const std::vector non_speech_tokens = { "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", @@ -3090,7 +3092,7 @@ static const std::vector non_speech_tokens // - applies logit filters // - computes logprobs and probs static void whisper_process_logits( - const struct whisper_context & ctx, + struct whisper_context & ctx, const struct whisper_full_params params, struct whisper_decoder & decoder, float temperature) { @@ -3146,29 +3148,27 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + if (params.logits_filter_callback) { + params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + } // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - if (params.suppress_non_speech_tokens) - { - for (const std::string &token : non_speech_tokens) - { - std::string suppress_tokens[] = {token, " " + token}; - for (const std::string &suppress_token : suppress_tokens) - { - if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) - { + if (params.suppress_non_speech_tokens) { + for (const std::string & token : non_speech_tokens) { + const std::string suppress_tokens[] = {token, " " + token}; + for (const std::string & suppress_token : suppress_tokens) { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; } } } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) - { + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(" -")] = -INFINITY; } - if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) - { + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { logits[vocab.token_to_id.at(" '")] = -INFINITY; } } @@ -3854,7 +3854,7 @@ int whisper_full( return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; }); - unsigned int cur_c = 0; + uint32_t cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -4339,7 +4339,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) { } int whisper_full_lang_id(struct whisper_context * ctx) { - return ctx->lang_id; + return ctx->lang_id; } int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 7eece79..3eb8d08 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -243,6 +243,16 @@ extern "C" { // If it returns false, the computation is aborted typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + // Logits filter callback + // Can be used to modify the logits before sampling + // If not NULL, called after applying temperature to logits + typedef void (*whisper_logits_filter_callback)( + struct whisper_context * ctx, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data); + // Parameters for the whisper_full() function // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp: // whisper_full_default_params() @@ -315,6 +325,10 @@ extern "C" { // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; + + // called by each decoder to filter obtained logits + whisper_logits_filter_callback logits_filter_callback; + void * logits_filter_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);