|
|
|
@ -25,7 +25,9 @@
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
// define this to enable verbose trace logging - useful for debugging purposes
|
|
|
|
|
//#define WHISPER_DEBUG
|
|
|
|
|
|
|
|
|
|
#if defined(WHISPER_DEBUG)
|
|
|
|
|
#define WHISPER_PRINT_DEBUG(...) \
|
|
|
|
|
do { \
|
|
|
|
@ -380,6 +382,8 @@ struct whisper_kv_cache {
|
|
|
|
|
struct ggml_context * ctx;
|
|
|
|
|
|
|
|
|
|
std::vector<uint8_t> buf;
|
|
|
|
|
|
|
|
|
|
int n; // number of tokens currently in the cache
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct whisper_model {
|
|
|
|
@ -442,12 +446,13 @@ struct whisper_sequence {
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
|
struct whisper_decoder {
|
|
|
|
|
// each decoder keeps its own KV-cache
|
|
|
|
|
whisper_kv_cache kv_self;
|
|
|
|
|
|
|
|
|
|
// the currently generated sequence of tokens
|
|
|
|
|
whisper_sequence sequence;
|
|
|
|
|
|
|
|
|
|
int n_past;
|
|
|
|
|
int seek_delta;
|
|
|
|
|
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
|
|
|
|
|
|
|
|
|
|
bool failed; // has the current segment failed to decode?
|
|
|
|
|
bool completed; // has the decoder completed the current segment?
|
|
|
|
@ -476,6 +481,8 @@ struct whisper_context {
|
|
|
|
|
whisper_model model;
|
|
|
|
|
whisper_vocab vocab;
|
|
|
|
|
|
|
|
|
|
// cross-attention KV cache for the decoders
|
|
|
|
|
// shared between all decoders
|
|
|
|
|
whisper_kv_cache kv_cross;
|
|
|
|
|
|
|
|
|
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
|
|
@ -490,6 +497,7 @@ struct whisper_context {
|
|
|
|
|
std::vector<whisper_segment> result_all;
|
|
|
|
|
std::vector<whisper_token> prompt_past;
|
|
|
|
|
|
|
|
|
|
// work container used to avoid memory allocations
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
|
|
|
|
|
|
|
|
|
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
|
|
@ -680,6 +688,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// initialize all memory buffers
|
|
|
|
|
// always have at least one decoder
|
|
|
|
|
|
|
|
|
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
|
|
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
|
|
|
|
|
|
|
@ -2671,127 +2682,77 @@ const char * whisper_print_system_info(void) {
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
|
|
|
|
struct whisper_full_params result;
|
|
|
|
|
struct whisper_full_params result = {
|
|
|
|
|
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
|
|
|
|
|
|
|
|
|
|
switch (strategy) {
|
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
|
result = {
|
|
|
|
|
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
|
|
|
|
|
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
|
|
|
/*.n_max_text_ctx =*/ 16384,
|
|
|
|
|
/*.offset_ms =*/ 0,
|
|
|
|
|
/*.duration_ms =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
|
|
|
/*.n_max_text_ctx =*/ 16384,
|
|
|
|
|
/*.offset_ms =*/ 0,
|
|
|
|
|
/*.duration_ms =*/ 0,
|
|
|
|
|
/*.translate =*/ false,
|
|
|
|
|
/*.no_context =*/ false,
|
|
|
|
|
/*.single_segment =*/ false,
|
|
|
|
|
/*.print_special =*/ false,
|
|
|
|
|
/*.print_progress =*/ true,
|
|
|
|
|
/*.print_realtime =*/ false,
|
|
|
|
|
/*.print_timestamps =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.translate =*/ false,
|
|
|
|
|
/*.no_context =*/ false,
|
|
|
|
|
/*.single_segment =*/ false,
|
|
|
|
|
/*.print_special =*/ false,
|
|
|
|
|
/*.print_progress =*/ true,
|
|
|
|
|
/*.print_realtime =*/ false,
|
|
|
|
|
/*.print_timestamps =*/ true,
|
|
|
|
|
/*.token_timestamps =*/ false,
|
|
|
|
|
/*.thold_pt =*/ 0.01f,
|
|
|
|
|
/*.thold_ptsum =*/ 0.01f,
|
|
|
|
|
/*.max_len =*/ 0,
|
|
|
|
|
/*.max_tokens =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.token_timestamps =*/ false,
|
|
|
|
|
/*.thold_pt =*/ 0.01f,
|
|
|
|
|
/*.thold_ptsum =*/ 0.01f,
|
|
|
|
|
/*.max_len =*/ 0,
|
|
|
|
|
/*.max_tokens =*/ 0,
|
|
|
|
|
/*.speed_up =*/ false,
|
|
|
|
|
/*.audio_ctx =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.speed_up =*/ false,
|
|
|
|
|
/*.audio_ctx =*/ 0,
|
|
|
|
|
/*.prompt_tokens =*/ nullptr,
|
|
|
|
|
/*.prompt_n_tokens =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.prompt_tokens =*/ nullptr,
|
|
|
|
|
/*.prompt_n_tokens =*/ 0,
|
|
|
|
|
/*.language =*/ "en",
|
|
|
|
|
|
|
|
|
|
/*.language =*/ "en",
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
|
/*.max_initial_ts =*/ 1.0f,
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
/*.temperature_inc =*/ 0.2f,
|
|
|
|
|
/*.entropy_thold =*/ 2.4f,
|
|
|
|
|
/*.logprob_thold =*/ -1.0f,
|
|
|
|
|
/*.no_speech_thold =*/ 0.6f,
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
|
/*.entropy_threshold =*/ 2.4f,
|
|
|
|
|
/*.logprob_threshold =*/ -1.0f,
|
|
|
|
|
/*.no_speech_threshold =*/ 0.6f,
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.best_of =*/ -1,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.best_of =*/ 5,
|
|
|
|
|
},
|
|
|
|
|
/*.beam_search =*/ {
|
|
|
|
|
/*.beam_size =*/ -1,
|
|
|
|
|
|
|
|
|
|
/*.beam_search =*/ {
|
|
|
|
|
/*.beam_size =*/ -1,
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
},
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
|
/*.new_segment_callback_user_data =*/ nullptr,
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
|
/*.new_segment_callback_user_data =*/ nullptr,
|
|
|
|
|
/*.encoder_begin_callback =*/ nullptr,
|
|
|
|
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/*.encoder_begin_callback =*/ nullptr,
|
|
|
|
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
|
switch (strategy) {
|
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
|
result.greedy = {
|
|
|
|
|
/*.best_of =*/ 1,
|
|
|
|
|
};
|
|
|
|
|
} break;
|
|
|
|
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
{
|
|
|
|
|
result = {
|
|
|
|
|
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
|
|
|
|
|
|
|
|
|
|
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
|
|
|
/*.n_max_text_ctx =*/ 16384,
|
|
|
|
|
/*.offset_ms =*/ 0,
|
|
|
|
|
/*.duration_ms =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.translate =*/ false,
|
|
|
|
|
/*.no_context =*/ false,
|
|
|
|
|
/*.single_segment =*/ false,
|
|
|
|
|
/*.print_special =*/ false,
|
|
|
|
|
/*.print_progress =*/ true,
|
|
|
|
|
/*.print_realtime =*/ false,
|
|
|
|
|
/*.print_timestamps =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.token_timestamps =*/ false,
|
|
|
|
|
/*.thold_pt =*/ 0.01f,
|
|
|
|
|
/*.thold_ptsum =*/ 0.01f,
|
|
|
|
|
/*.max_len =*/ 0,
|
|
|
|
|
/*.max_tokens =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.speed_up =*/ false,
|
|
|
|
|
/*.audio_ctx =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.prompt_tokens =*/ nullptr,
|
|
|
|
|
/*.prompt_n_tokens =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.language =*/ "en",
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
|
/*.entropy_threshold =*/ 2.4f,
|
|
|
|
|
/*.logprob_threshold =*/ -1.0f,
|
|
|
|
|
/*.no_speech_threshold =*/ 0.6f,
|
|
|
|
|
result.beam_search = {
|
|
|
|
|
/*.beam_size =*/ 5,
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.best_of =*/ 5,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.beam_search =*/ {
|
|
|
|
|
/*.beam_size =*/ 5,
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
|
/*.new_segment_callback_user_data =*/ nullptr,
|
|
|
|
|
|
|
|
|
|
/*.encoder_begin_callback =*/ nullptr,
|
|
|
|
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
};
|
|
|
|
|
} break;
|
|
|
|
|
}
|
|
|
|
@ -2862,8 +2823,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// process the logits for the selected decoder
|
|
|
|
|
// - applies logit filters
|
|
|
|
|
// - computes logprobs
|
|
|
|
|
// - applies logit filters
|
|
|
|
|
// - computes logprobs and probs
|
|
|
|
|
static void whisper_process_logits(
|
|
|
|
|
const struct whisper_context & ctx,
|
|
|
|
|
const struct whisper_full_params params,
|
|
|
|
@ -2938,11 +2899,11 @@ static void whisper_process_logits(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// the initial timestamp cannot be larger than max_initial_timestamp
|
|
|
|
|
// the initial timestamp cannot be larger than max_initial_ts
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
|
|
|
|
|
if (is_initial && params.max_initial_timestamp > 0.0f) {
|
|
|
|
|
if (is_initial && params.max_initial_ts > 0.0f) {
|
|
|
|
|
const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
|
|
|
|
|
const int tid0 = std::round(params.max_initial_timestamp/precision);
|
|
|
|
|
const int tid0 = std::round(params.max_initial_ts/precision);
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
|
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
@ -3073,7 +3034,7 @@ static whisper_token_data whisper_sample_token(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result.pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
result.pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
result.ptsum = sum_ts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3127,7 +3088,13 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token_data> result(k);
|
|
|
|
|
std::vector<whisper_token_data> result;
|
|
|
|
|
result.reserve(k);
|
|
|
|
|
|
|
|
|
|
whisper_token tid;
|
|
|
|
|
|
|
|
|
|
float pt;
|
|
|
|
|
float ptsum;
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
double sum_ts = 0.0;
|
|
|
|
@ -3141,24 +3108,18 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
|
|
sum_ts += probs[i];
|
|
|
|
|
if (max_ts < probs[i]) {
|
|
|
|
|
max_ts = probs[i];
|
|
|
|
|
result[0].tid = i;
|
|
|
|
|
tid = i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result[0].pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
result[0].ptsum = sum_ts;
|
|
|
|
|
pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
ptsum = sum_ts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < k; ++i) {
|
|
|
|
|
result[i].id = logits_id[i].second;
|
|
|
|
|
result[i].p = probs[result[i].id];
|
|
|
|
|
result[i].plog = logprobs[result[i].id];
|
|
|
|
|
result[i].tid = result[0].tid;
|
|
|
|
|
result[i].pt = result[0].pt;
|
|
|
|
|
result[i].ptsum = result[0].ptsum;
|
|
|
|
|
result[i].t0 = -1;
|
|
|
|
|
result[i].t1 = -1;
|
|
|
|
|
result[i].vlen = 0.0f;
|
|
|
|
|
const auto id = logits_id[i].second;
|
|
|
|
|
|
|
|
|
|
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
|
|
|
|
|
|
|
|
|
if (result[i].id >= vocab.token_beg) {
|
|
|
|
|
result[i].tid = result[i].id;
|
|
|
|
@ -3276,8 +3237,8 @@ int whisper_full(
|
|
|
|
|
// a set of temperatures to use
|
|
|
|
|
// [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
|
|
|
|
|
std::vector<float> temperatures;
|
|
|
|
|
if (params.temperature_increment > 0.0f) {
|
|
|
|
|
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) {
|
|
|
|
|
if (params.temperature_inc > 0.0f) {
|
|
|
|
|
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
|
|
|
|
|
temperatures.push_back(t);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
@ -3448,6 +3409,8 @@ int whisper_full(
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
decoder.kv_self.n = 0;
|
|
|
|
|
|
|
|
|
|
decoder.sequence.tokens.clear();
|
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
|
decoder.sequence.sum_logprobs_all = 0.0;
|
|
|
|
@ -3456,7 +3419,6 @@ int whisper_full(
|
|
|
|
|
decoder.sequence.entropy = 0.0;
|
|
|
|
|
decoder.sequence.score = -INFINITY;
|
|
|
|
|
|
|
|
|
|
decoder.n_past = 0;
|
|
|
|
|
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
|
|
decoder.failed = false;
|
|
|
|
@ -3497,7 +3459,7 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
|
|
|
|
|
|
|
|
|
|
ctx->decoders[0].n_past += prompt.size();
|
|
|
|
|
ctx->decoders[0].kv_self.n += prompt.size();
|
|
|
|
|
|
|
|
|
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
@ -3505,7 +3467,7 @@ int whisper_full(
|
|
|
|
|
memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
|
|
|
|
|
memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
|
|
|
|
|
|
|
|
|
|
decoder.n_past += prompt.size();
|
|
|
|
|
decoder.kv_self.n += prompt.size();
|
|
|
|
|
|
|
|
|
|
memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
|
|
|
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
|
|
@ -3519,6 +3481,7 @@ int whisper_full(
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
// store the KV caches of all decoders when doing beam-search
|
|
|
|
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
|
|
|
kv_bufs.resize(n_decoders_cur);
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
@ -3538,6 +3501,7 @@ int whisper_full(
|
|
|
|
|
beam_candidates.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate new sequence candidates for each decoder
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
@ -3571,6 +3535,7 @@ int whisper_full(
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// for beam-search, choose the top candidates and update the KV caches
|
|
|
|
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
|
|
|
std::sort(
|
|
|
|
|
beam_candidates.begin(),
|
|
|
|
@ -3606,6 +3571,10 @@ int whisper_full(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// update the decoder state
|
|
|
|
|
// - check if the sequence is completed
|
|
|
|
|
// - check if the sequence has failed
|
|
|
|
|
// - update sliding window based on timestamp tokens
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
@ -3705,6 +3674,7 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
|
|
|
|
|
|
// obtain logits for the next token
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
@ -3715,11 +3685,11 @@ int whisper_full(
|
|
|
|
|
decoder.tokens_tmp.resize(1);
|
|
|
|
|
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
|
|
|
|
|
|
|
|
|
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
|
|
|
|
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
|
|
|
|
|
|
|
|
|
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
|
|
|
|
|
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
return -7;
|
|
|
|
|
return -8;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
@ -3727,7 +3697,7 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
whisper_process_logits(*ctx, params, decoder, t_cur);
|
|
|
|
|
|
|
|
|
|
++decoder.n_past;
|
|
|
|
|
++decoder.kv_self.n;
|
|
|
|
|
|
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
|
}
|
|
|
|
@ -3736,7 +3706,7 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
// rank the resulting sequences and select the best one
|
|
|
|
|
{
|
|
|
|
|
double best_score = -1e9;
|
|
|
|
|
double best_score = -INFINITY;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
@ -3751,9 +3721,9 @@ int whisper_full(
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
|
|
|
|
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
|
|
|
|
|
|
|
|
|
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
|
|
|
|
|
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
|
|
|
|
__func__, j, decoder.sequence.entropy, params.entropy_threshold);
|
|
|
|
|
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
|
|
|
|
|
|
|
|
|
decoder.failed = true;
|
|
|
|
|
|
|
|
|
@ -3765,31 +3735,33 @@ int whisper_full(
|
|
|
|
|
best_decoder_id = j;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool success = true;
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// implement logprob threshold
|
|
|
|
|
// was the decoding successful for the current temperature?
|
|
|
|
|
{
|
|
|
|
|
auto & decoder = ctx->decoders[best_decoder_id];
|
|
|
|
|
bool success = true;
|
|
|
|
|
|
|
|
|
|
const auto & decoder = ctx->decoders[best_decoder_id];
|
|
|
|
|
|
|
|
|
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) {
|
|
|
|
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
|
|
|
|
success = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (success) {
|
|
|
|
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
|
|
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
|
|
|
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
|
//}
|
|
|
|
|
if (success) {
|
|
|
|
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
|
|
|
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// output results through a user-provided callback
|
|
|
|
|
{
|
|
|
|
|
const auto & best_decoder = ctx->decoders[best_decoder_id];
|
|
|
|
|
|
|
|
|
|