diff --git a/whisper.cpp b/whisper.cpp index 04cbc36..9a9e562 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con MEM_REQ_SCRATCH3.at (model.type) + scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); // this is the memory required by one decoder const size_t mem_required_decoder = @@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, + + /*.logits_filter_callback =*/ nullptr, + /*.logits_filter_callback_user_data =*/ nullptr, }; switch (strategy) { @@ -3089,7 +3092,7 @@ static const std::vector non_speech_tokens = { // - applies logit filters // - computes logprobs and probs static void whisper_process_logits( - const struct whisper_context & ctx, + struct whisper_context & ctx, const struct whisper_full_params params, struct whisper_decoder & decoder, float temperature) { @@ -3145,6 +3148,9 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + if (params.logits_filter_callback) { + params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + } // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 @@ -3848,7 +3854,7 @@ int whisper_full( return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; }); - unsigned int cur_c = 0; + uint32_t cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; diff --git a/whisper.h b/whisper.h index 7eece79..3eb8d08 100644 --- a/whisper.h +++ 
b/whisper.h @@ -243,6 +243,16 @@ extern "C" { // If it returns false, the computation is aborted typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + // Logits filter callback + // Can be used to modify the logits before sampling + // If not NULL, called after applying temperature to logits + typedef void (*whisper_logits_filter_callback)( + struct whisper_context * ctx, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data); + // Parameters for the whisper_full() function // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: // whisper_full_default_params() @@ -315,6 +325,10 @@ extern "C" { // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; + + // called by each decoder to filter obtained logits + whisper_logits_filter_callback logits_filter_callback; + void * logits_filter_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);