|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
#include <thread>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <regex>
|
|
|
|
|
#include <random>
|
|
|
|
|
|
|
|
|
|
#define WHISPER_ASSERT(x) \
|
|
|
|
|
do { \
|
|
|
|
@ -479,6 +480,8 @@ struct whisper_context {
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
|
|
|
|
|
|
|
|
|
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
|
|
|
|
|
|
|
|
// [EXPERIMENTAL] token-level timestamps data
|
|
|
|
|
int64_t t_beg;
|
|
|
|
|
int64_t t_last;
|
|
|
|
@ -1180,6 +1183,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
wctx.rng = std::mt19937(0);
|
|
|
|
|
|
|
|
|
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
@ -2848,9 +2853,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
|
|
|
|
|
// - applyies logit filters
|
|
|
|
|
// - computes logprobs
|
|
|
|
|
static void whisper_process_logits(
|
|
|
|
|
struct whisper_context & ctx,
|
|
|
|
|
struct whisper_decoder & decoder,
|
|
|
|
|
struct whisper_full_params params) {
|
|
|
|
|
const struct whisper_context & ctx,
|
|
|
|
|
const struct whisper_full_params params,
|
|
|
|
|
struct whisper_decoder & decoder,
|
|
|
|
|
float temperature) {
|
|
|
|
|
const auto & vocab = ctx.vocab;
|
|
|
|
|
const auto & tokens_cur = decoder.sequence.tokens;
|
|
|
|
|
|
|
|
|
@ -2868,6 +2874,12 @@ static void whisper_process_logits(
|
|
|
|
|
logits.resize(n_logits);
|
|
|
|
|
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
|
|
|
|
|
|
if (temperature > 0.0f) {
|
|
|
|
|
for (int i = 0; i < n_logits; i++) {
|
|
|
|
|
logits[i] /= temperature;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// will be populated a bit later
|
|
|
|
|
probs.resize(n_logits);
|
|
|
|
|
logprobs.resize(n_logits);
|
|
|
|
@ -3018,10 +3030,10 @@ static void whisper_process_logits(
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// select the most probable token
|
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
|
whisper_context & ctx,
|
|
|
|
|
whisper_decoder & decoder) {
|
|
|
|
|
static whisper_token_data whisper_sample_token(
|
|
|
|
|
const whisper_context & ctx,
|
|
|
|
|
const whisper_decoder & decoder,
|
|
|
|
|
bool best) {
|
|
|
|
|
whisper_token_data result = {
|
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
|
};
|
|
|
|
@ -3053,12 +3065,20 @@ static whisper_token_data whisper_sample_best(
|
|
|
|
|
result.ptsum = sum_ts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
|
if (result.p < probs[i]) {
|
|
|
|
|
result.id = i;
|
|
|
|
|
result.p = probs[i];
|
|
|
|
|
result.plog = logprobs[i];
|
|
|
|
|
if (best) {
|
|
|
|
|
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
|
if (result.p < probs[i]) {
|
|
|
|
|
result.id = i;
|
|
|
|
|
result.p = probs[i];
|
|
|
|
|
result.plog = logprobs[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
|
|
|
|
|
|
|
|
result.id = dist(ctx.rng);
|
|
|
|
|
result.p = probs[result.id];
|
|
|
|
|
result.plog = logprobs[result.id];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
@ -3167,22 +3187,24 @@ int whisper_full(
|
|
|
|
|
} break;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < n_decoders; i++) {
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
|
if (ctx->decoders[i].kv_self.ctx == nullptr) {
|
|
|
|
|
ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
|
|
|
|
|
if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i);
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
|
for (int j = 1; j < n_decoders; j++) {
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
if (decoder.kv_self.ctx == nullptr) {
|
|
|
|
|
decoder.kv_self = ctx->decoders[0].kv_self;
|
|
|
|
|
if (!kv_cache_reinit(decoder.kv_self)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
|
|
|
|
return -4;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
|
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
|
decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].probs.resize (ctx->vocab.n_vocab);
|
|
|
|
|
ctx->decoders[i].logits.resize (ctx->vocab.n_vocab);
|
|
|
|
|
ctx->decoders[i].logprobs.resize(ctx->vocab.n_vocab);
|
|
|
|
|
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
|
|
|
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
|
|
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3268,6 +3290,7 @@ int whisper_full(
|
|
|
|
|
const float t_cur = temperatures[it];
|
|
|
|
|
|
|
|
|
|
int n_decoders_cur = 1;
|
|
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
@ -3330,7 +3353,7 @@ int whisper_full(
|
|
|
|
|
// print the prompt
|
|
|
|
|
//printf("\n\n");
|
|
|
|
|
//for (int i = 0; i < (int) prompt.size(); i++) {
|
|
|
|
|
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
|
|
|
|
|
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
|
|
|
|
|
//}
|
|
|
|
|
//printf("\n\n");
|
|
|
|
|
|
|
|
|
@ -3342,7 +3365,7 @@ int whisper_full(
|
|
|
|
|
{
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
whisper_process_logits(*ctx, ctx->decoders[0], params);
|
|
|
|
|
whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
|
|
|
|
|
|
|
|
|
|
ctx->decoders[0].n_past += prompt.size();
|
|
|
|
|
|
|
|
|
@ -3376,8 +3399,9 @@ int whisper_full(
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
|
if (t_cur < 1e-6f) {
|
|
|
|
|
decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder));
|
|
|
|
|
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
|
|
|
|
} else {
|
|
|
|
|
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
|
|
|
|
}
|
|
|
|
|
} break;
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
@ -3402,7 +3426,7 @@ int whisper_full(
|
|
|
|
|
// do not allow to go back in time
|
|
|
|
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
failed = true; // TODO: maybe this is not a failure ?
|
|
|
|
|
break;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
seek_delta = seek_delta_new;
|
|
|
|
@ -3410,11 +3434,11 @@ int whisper_full(
|
|
|
|
|
has_ts = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
|
printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
|
|
|
|
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
|
}
|
|
|
|
|
//{
|
|
|
|
|
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
|
|
|
// printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
|
|
|
|
// __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
// end of segment
|
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
@ -3426,7 +3450,7 @@ int whisper_full(
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
} else {
|
|
|
|
|
failed = true;
|
|
|
|
|
break;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3436,14 +3460,14 @@ int whisper_full(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
completed = true;
|
|
|
|
|
break;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TESTS: if no tensors are loaded, it means we are running tests
|
|
|
|
|
if (ctx->model.n_loaded == 0) {
|
|
|
|
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
completed = true;
|
|
|
|
|
break;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3451,7 +3475,7 @@ int whisper_full(
|
|
|
|
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
|
|
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
|
|
|
failed = true;
|
|
|
|
|
break;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3496,7 +3520,7 @@ int whisper_full(
|
|
|
|
|
{
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
whisper_process_logits(*ctx, decoder, params);
|
|
|
|
|
whisper_process_logits(*ctx, params, decoder, t_cur);
|
|
|
|
|
|
|
|
|
|
++decoder.n_past;
|
|
|
|
|
|
|
|
|
@ -3524,7 +3548,7 @@ int whisper_full(
|
|
|
|
|
best_decoder_id = j;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
|
|
|
|
|
//fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -3541,6 +3565,11 @@ int whisper_full(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (success) {
|
|
|
|
|
//fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
|
|
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
|
|
|
// fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|