|
|
@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
|
|
/*.encoder_begin_callback =*/ nullptr,
|
|
|
|
/*.encoder_begin_callback =*/ nullptr,
|
|
|
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*.logits_filter_callback =*/ nullptr,
|
|
|
|
|
|
|
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
switch (strategy) {
|
|
|
|
switch (strategy) {
|
|
|
@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
|
|
// - applies logit filters
|
|
|
|
// - applies logit filters
|
|
|
|
// - computes logprobs and probs
|
|
|
|
// - computes logprobs and probs
|
|
|
|
static void whisper_process_logits(
|
|
|
|
static void whisper_process_logits(
|
|
|
|
const struct whisper_context & ctx,
|
|
|
|
struct whisper_context & ctx,
|
|
|
|
const struct whisper_full_params params,
|
|
|
|
const struct whisper_full_params params,
|
|
|
|
struct whisper_decoder & decoder,
|
|
|
|
struct whisper_decoder & decoder,
|
|
|
|
float temperature) {
|
|
|
|
float temperature) {
|
|
|
@ -3145,6 +3148,9 @@ static void whisper_process_logits(
|
|
|
|
logits[vocab.token_translate] = -INFINITY;
|
|
|
|
logits[vocab.token_translate] = -INFINITY;
|
|
|
|
logits[vocab.token_transcribe] = -INFINITY;
|
|
|
|
logits[vocab.token_transcribe] = -INFINITY;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (params.logits_filter_callback) {
|
|
|
|
|
|
|
|
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// suppress non-speech tokens
|
|
|
|
// suppress non-speech tokens
|
|
|
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
|
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
|
@ -3848,7 +3854,7 @@ int whisper_full(
|
|
|
|
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
|
|
|
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
unsigned int cur_c = 0;
|
|
|
|
uint32_t cur_c = 0;
|
|
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|