|
|
@ -422,7 +422,9 @@ struct whisper_sequence {
|
|
|
|
// the accumulated transcription in the current iteration (used to truncate the tokens array)
|
|
|
|
// the accumulated transcription in the current iteration (used to truncate the tokens array)
|
|
|
|
int result_len;
|
|
|
|
int result_len;
|
|
|
|
|
|
|
|
|
|
|
|
double sum_logprobs;
|
|
|
|
double sum_logprobs; // the sum of the log probabilities of the tokens
|
|
|
|
|
|
|
|
double avg_logprobs; // the average log probability of the tokens
|
|
|
|
|
|
|
|
double score; // likelihood rank score
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
@ -438,11 +440,12 @@ struct whisper_decoder {
|
|
|
|
bool completed; // has the decoder completed the current segment?
|
|
|
|
bool completed; // has the decoder completed the current segment?
|
|
|
|
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> tokens;
|
|
|
|
// new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
|
|
|
|
|
|
|
|
std::vector<float> probs;
|
|
|
|
// new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
|
|
|
|
|
|
|
|
std::vector<float> logits;
|
|
|
|
std::vector<float> logits;
|
|
|
|
std::vector<float> logprobs;
|
|
|
|
std::vector<float> logprobs;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct whisper_context {
|
|
|
|
struct whisper_context {
|
|
|
@ -774,6 +777,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
|
|
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wctx.decoders[0].probs.reserve (vocab.n_vocab);
|
|
|
|
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
|
|
|
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
|
|
|
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
|
|
|
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -1959,99 +1963,6 @@ static bool whisper_decode(
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
|
|
|
|
whisper_context & ctx,
|
|
|
|
|
|
|
|
const float * probs,
|
|
|
|
|
|
|
|
bool force_timestamp,
|
|
|
|
|
|
|
|
bool is_initial) {
|
|
|
|
|
|
|
|
whisper_token_data result = {
|
|
|
|
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto & vocab = ctx.vocab;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const int n_logits = vocab.n_vocab;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto & probs_id = ctx.probs_id;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
probs_id.clear();
|
|
|
|
|
|
|
|
for (int i = 0; i < n_logits; i++) {
|
|
|
|
|
|
|
|
probs_id.emplace_back(probs[i], i);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
double sum_ts = 0.0;
|
|
|
|
|
|
|
|
double max_ts = -1.0;
|
|
|
|
|
|
|
|
double max_tx = -1.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < vocab.token_beg; i++) {
|
|
|
|
|
|
|
|
max_tx = std::max(max_tx, probs_id[i].first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
|
|
|
|
|
|
|
|
const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// the initial timestamp cannot be larger than 100
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
|
|
|
|
|
|
|
|
if (is_initial) {
|
|
|
|
|
|
|
|
for (int i = i0; i < n_logits; ++ i) {
|
|
|
|
|
|
|
|
probs_id[i].first = -INFINITY;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg; i < i1; i++) {
|
|
|
|
|
|
|
|
sum_ts += probs_id[i].first;
|
|
|
|
|
|
|
|
if (probs_id[i].first > max_ts) {
|
|
|
|
|
|
|
|
max_ts = probs_id[i].first;
|
|
|
|
|
|
|
|
result.tid = probs_id[i].second;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
|
|
|
|
|
|
|
|
// timestamp token
|
|
|
|
|
|
|
|
if (sum_ts > max_tx || force_timestamp) {
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
|
|
|
|
|
|
|
|
for (int i = 0; i < vocab.token_beg; i++) {
|
|
|
|
|
|
|
|
probs_id[i].first = -INFINITY;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result.pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
|
|
|
result.ptsum = sum_ts;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// find the top K tokens
|
|
|
|
|
|
|
|
const int top_k = 4;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::partial_sort(
|
|
|
|
|
|
|
|
probs_id.begin(),
|
|
|
|
|
|
|
|
probs_id.begin() + top_k, probs_id.end(),
|
|
|
|
|
|
|
|
[](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
|
|
|
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
probs_id.resize(top_k);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//printf("\n");
|
|
|
|
|
|
|
|
//for (int i = 0; i < (int) probs_id.size(); i++) {
|
|
|
|
|
|
|
|
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int res = 0;
|
|
|
|
|
|
|
|
while ((probs_id[res].second == vocab.token_sot ||
|
|
|
|
|
|
|
|
probs_id[res].second == vocab.token_solm ||
|
|
|
|
|
|
|
|
probs_id[res].second == vocab.token_not) &&
|
|
|
|
|
|
|
|
res < (int) probs_id.size() - 1) {
|
|
|
|
|
|
|
|
res++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result.id = probs_id[res].second;
|
|
|
|
|
|
|
|
result.p = probs_id[res].first;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 500 -> 00:05.000
|
|
|
|
// 500 -> 00:05.000
|
|
|
|
// 6000 -> 01:00.000
|
|
|
|
// 6000 -> 01:00.000
|
|
|
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
@ -2793,8 +2704,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
@ -2809,7 +2721,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
/*.beam_size =*/ -1,
|
|
|
|
/*.beam_size =*/ -1,
|
|
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
},
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
@ -2853,8 +2764,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
@ -2869,7 +2781,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
/*.beam_size =*/ 5,
|
|
|
|
/*.beam_size =*/ 5,
|
|
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
|
|
|
},
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
@ -2963,9 +2874,13 @@ static void whisper_process_logits(
|
|
|
|
|
|
|
|
|
|
|
|
// extract the logits for the last token
|
|
|
|
// extract the logits for the last token
|
|
|
|
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
|
|
|
|
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
|
|
|
|
|
|
|
|
auto & probs = decoder.probs;
|
|
|
|
auto & logits = decoder.logits;
|
|
|
|
auto & logits = decoder.logits;
|
|
|
|
auto & logprobs = decoder.logprobs;
|
|
|
|
auto & logprobs = decoder.logprobs;
|
|
|
|
{
|
|
|
|
{
|
|
|
|
|
|
|
|
probs.resize(n_logits);
|
|
|
|
|
|
|
|
memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
|
|
|
|
|
|
|
|
logits.resize(n_logits);
|
|
|
|
logits.resize(n_logits);
|
|
|
|
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
|
|
|
|
|
|
|
@ -2995,6 +2910,8 @@ static void whisper_process_logits(
|
|
|
|
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
|
|
|
|
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
|
|
|
|
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
|
|
|
|
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
|
|
|
|
|
|
|
|
|
|
|
|
if (last_was_timestamp) {
|
|
|
|
if (last_was_timestamp) {
|
|
|
|
if (penultimate_was_timestamp) {
|
|
|
|
if (penultimate_was_timestamp) {
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++ i) {
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++ i) {
|
|
|
@ -3099,6 +3016,83 @@ static void whisper_process_logits(
|
|
|
|
//return result;
|
|
|
|
//return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// select the most probable token
|
|
|
|
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
|
|
|
|
whisper_context & ctx,
|
|
|
|
|
|
|
|
whisper_decoder & decoder) {
|
|
|
|
|
|
|
|
whisper_token_data result = {
|
|
|
|
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto & vocab = ctx.vocab;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto & probs = decoder.probs;
|
|
|
|
|
|
|
|
const auto & logprobs = decoder.logprobs;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const int n_logits = vocab.n_vocab;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
double sum_ts = 0.0;
|
|
|
|
|
|
|
|
double max_ts = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; i++) {
|
|
|
|
|
|
|
|
if (probs[i] == -INFINITY) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sum_ts += probs[i];
|
|
|
|
|
|
|
|
if (max_ts < probs[i]) {
|
|
|
|
|
|
|
|
max_ts = probs[i];
|
|
|
|
|
|
|
|
result.tid = i;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result.pt = max_ts/(sum_ts + 1e-10);
|
|
|
|
|
|
|
|
result.ptsum = sum_ts;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
|
|
|
|
// never sample these:
|
|
|
|
|
|
|
|
if (i == vocab.token_sot ||
|
|
|
|
|
|
|
|
i == vocab.token_solm ||
|
|
|
|
|
|
|
|
i == vocab.token_not) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (result.p < probs[i]) {
|
|
|
|
|
|
|
|
result.id = i;
|
|
|
|
|
|
|
|
result.p = probs[i];
|
|
|
|
|
|
|
|
result.plog = logprobs[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
|
|
|
|
|
|
|
|
static void whisper_sequence_score(
|
|
|
|
|
|
|
|
const struct whisper_full_params & params,
|
|
|
|
|
|
|
|
whisper_sequence & sequence) {
|
|
|
|
|
|
|
|
WHISPER_ASSERT(sequence.result_len > 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double result = 0.0f;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < sequence.result_len; ++i) {
|
|
|
|
|
|
|
|
result += sequence.tokens[i].plog;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sequence.sum_logprobs = result;
|
|
|
|
|
|
|
|
sequence.avg_logprobs = result/sequence.result_len;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double penalty = sequence.result_len;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (params.length_penalty > 0.0f) {
|
|
|
|
|
|
|
|
penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sequence.score = result/penalty;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_full(
|
|
|
|
int whisper_full(
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
struct whisper_full_params params,
|
|
|
|
struct whisper_full_params params,
|
|
|
@ -3191,6 +3185,7 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].probs.reserve (ctx->vocab.n_vocab);
|
|
|
|
ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab);
|
|
|
|
ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab);
|
|
|
|
ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
|
|
|
|
ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -3261,7 +3256,7 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
|
|
|
if (!whisper_encode(*ctx, seek, params.n_threads)) {
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
return -6;
|
|
|
|
return -6;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -3272,132 +3267,7 @@ int whisper_full(
|
|
|
|
prompt_past.clear();
|
|
|
|
prompt_past.clear();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//prompt.clear();
|
|
|
|
int best_decoder_id = 0;
|
|
|
|
|
|
|
|
|
|
|
|
//// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
|
|
|
|
|
//if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
// int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// prompt = { whisper_token_prev(ctx) };
|
|
|
|
|
|
|
|
// prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// prompt_past.clear();
|
|
|
|
|
|
|
|
// prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// print the prompt
|
|
|
|
|
|
|
|
////printf("\n\n");
|
|
|
|
|
|
|
|
////for (int i = 0; i < prompt.size(); i++) {
|
|
|
|
|
|
|
|
//// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
|
|
|
|
|
|
|
|
////}
|
|
|
|
|
|
|
|
////printf("\n\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//int n_past = 0;
|
|
|
|
|
|
|
|
//int seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// the accumulated transcription in the current iteration
|
|
|
|
|
|
|
|
//int result_len = 0;
|
|
|
|
|
|
|
|
//tokens_cur.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//bool failed = false; // has the current segment failed to decode?
|
|
|
|
|
|
|
|
//bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//double sum_logprobs = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
|
|
|
// if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
|
|
|
// fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
|
|
|
// return -7;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// n_past += prompt.size();
|
|
|
|
|
|
|
|
// prompt.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // sample the next token based on the selected decoding strategy + parameters
|
|
|
|
|
|
|
|
// // also, update the sliding window position based on the sampled timestamp tokens
|
|
|
|
|
|
|
|
// {
|
|
|
|
|
|
|
|
// const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // timestamp token - update sliding window
|
|
|
|
|
|
|
|
// if (token.id > whisper_token_beg(ctx)) {
|
|
|
|
|
|
|
|
// const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // do not allow to go back in time
|
|
|
|
|
|
|
|
// if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// seek_delta = seek_delta_new;
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// has_ts = true;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // add it to the context
|
|
|
|
|
|
|
|
// prompt.push_back(token.id);
|
|
|
|
|
|
|
|
// tokens_cur.push_back(token);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// //{
|
|
|
|
|
|
|
|
// // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
|
|
|
|
// // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
|
|
|
|
// //}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // end of segment
|
|
|
|
|
|
|
|
// if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
|
|
|
|
// (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
|
|
|
|
|
|
// (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
|
|
|
|
|
|
// ) {
|
|
|
|
|
|
|
|
// if (result_len == 0) {
|
|
|
|
|
|
|
|
// if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// } else {
|
|
|
|
|
|
|
|
// failed = true;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if (params.single_segment) {
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // TESTS: if no tensors are loaded, it means we are running tests
|
|
|
|
|
|
|
|
// if (ctx->model.n_loaded == 0) {
|
|
|
|
|
|
|
|
// seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // sometimes, the decoding can get stuck in a repetition loop
|
|
|
|
|
|
|
|
// // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
|
|
|
|
|
|
// if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
|
|
|
|
|
|
// failed = true;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//if (failed) {
|
|
|
|
|
|
|
|
// // when we fail to sample timestamp token, retry by clearing the past prompt
|
|
|
|
|
|
|
|
// // if it fails again, then we advance the window by 1 second
|
|
|
|
|
|
|
|
// if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
// prompt_past.clear();
|
|
|
|
|
|
|
|
// } else {
|
|
|
|
|
|
|
|
// fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
|
|
|
|
|
|
|
|
// seek += 100;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// continue;
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// shrink down to result_len
|
|
|
|
|
|
|
|
//tokens_cur.resize(result_len);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//for (const auto & r : tokens_cur) {
|
|
|
|
|
|
|
|
// prompt_past.push_back(r.id);
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int it = 0; it < (int) temperatures.size(); ++it) {
|
|
|
|
for (int it = 0; it < (int) temperatures.size(); ++it) {
|
|
|
|
const float t_cur = temperatures[it];
|
|
|
|
const float t_cur = temperatures[it];
|
|
|
@ -3429,12 +3299,14 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
|
|
decoder.sequence.tokens.clear();
|
|
|
|
decoder.sequence.tokens.clear();
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
decoder.sequence.sum_logprobs = 0.0;
|
|
|
|
decoder.sequence.sum_logprobs = 0.0;
|
|
|
|
|
|
|
|
decoder.sequence.avg_logprobs = 0.0;
|
|
|
|
|
|
|
|
decoder.sequence.score = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
decoder.n_past = 0;
|
|
|
|
decoder.n_past = 0;
|
|
|
|
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
@ -3467,65 +3339,148 @@ int whisper_full(
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
//printf("\n\n");
|
|
|
|
//printf("\n\n");
|
|
|
|
|
|
|
|
|
|
|
|
if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) {
|
|
|
|
if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) {
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
return -7;
|
|
|
|
return -7;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
whisper_process_logits(*ctx, ctx->decoders[0], params);
|
|
|
|
whisper_process_logits(*ctx, ctx->decoders[0], params);
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < n_decoders_cur; ++i) {
|
|
|
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
|
|
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
|
|
|
|
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
|
|
|
|
|
|
|
|
|
|
|
|
decoder.n_past += prompt.size();
|
|
|
|
decoder.n_past += prompt.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
|
|
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
|
|
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
|
|
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
|
|
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
bool completed = true;
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
|
|
|
{
|
|
|
|
if (t_cur < 1e-6f) {
|
|
|
|
if (t_cur < 1e-6f) {
|
|
|
|
// select top token
|
|
|
|
decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
// TODO: ..
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto & has_ts = decoder.has_ts;
|
|
|
|
|
|
|
|
auto & failed = decoder.failed;
|
|
|
|
|
|
|
|
auto & completed = decoder.completed;
|
|
|
|
|
|
|
|
auto & seek_delta = decoder.seek_delta;
|
|
|
|
|
|
|
|
auto & result_len = decoder.sequence.result_len;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
const auto & token = decoder.sequence.tokens.back();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// timestamp token - update sliding window
|
|
|
|
|
|
|
|
if (token.id > whisper_token_beg(ctx)) {
|
|
|
|
|
|
|
|
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// do not allow to go back in time
|
|
|
|
|
|
|
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
|
|
|
failed = true; // TODO: maybe this is not a failure ?
|
|
|
|
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
seek_delta = seek_delta_new;
|
|
|
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
|
|
|
has_ts = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
// TODO: ..
|
|
|
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
} break;
|
|
|
|
printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (completed) {
|
|
|
|
// end of segment
|
|
|
|
break;
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
|
|
|
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
|
|
|
|
|
|
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
|
|
|
|
|
|
) {
|
|
|
|
|
|
|
|
if (result_len == 0) {
|
|
|
|
|
|
|
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
failed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (params.single_segment) {
|
|
|
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
|
|
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
completed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TESTS: if no tensors are loaded, it means we are running tests
|
|
|
|
|
|
|
|
if (ctx->model.n_loaded == 0) {
|
|
|
|
|
|
|
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
completed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// sometimes, the decoding can get stuck in a repetition loop
|
|
|
|
|
|
|
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
|
|
|
|
|
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
|
|
|
|
|
|
failed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// check if all decoders have finished (i.e. completed or failed)
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
bool completed_all = true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
completed_all = false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (completed_all) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.failed || decoder.completed) {
|
|
|
|
if (decoder.failed || decoder.completed) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
decoder.tokens.resize(1);
|
|
|
|
decoder.tokens_tmp.resize(1);
|
|
|
|
decoder.tokens[0] = decoder.sequence.tokens.back().id;
|
|
|
|
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
|
|
|
|
|
|
|
|
|
|
|
|
if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) {
|
|
|
|
if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) {
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
return -7;
|
|
|
|
return -7;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -3536,14 +3491,36 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: rank the resulting sequences and select the best one
|
|
|
|
// rank the resulting sequences and select the best one
|
|
|
|
{
|
|
|
|
{
|
|
|
|
|
|
|
|
double best_score = -1e9;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[j];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.failed) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
whisper_sequence_score(params, ctx->decoders[j].sequence);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (best_score < decoder.sequence.score) {
|
|
|
|
|
|
|
|
best_score = decoder.sequence.score;
|
|
|
|
|
|
|
|
best_decoder_id = j;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool success = true;
|
|
|
|
bool success = true;
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: implement logprob threshold + compression threshold
|
|
|
|
// implement logprob threshold
|
|
|
|
|
|
|
|
// TODO: implement compression threshold
|
|
|
|
{
|
|
|
|
{
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[best_decoder_id];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.sequence.avg_logprobs < params.logprob_threshold) {
|
|
|
|
|
|
|
|
success = false;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (success) {
|
|
|
|
if (success) {
|
|
|
@ -3553,113 +3530,121 @@ int whisper_full(
|
|
|
|
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO
|
|
|
|
{
|
|
|
|
const int seek_delta = 0;
|
|
|
|
const auto & best_decoder = ctx->decoders[best_decoder_id];
|
|
|
|
whisper_sequence seq_best;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto & tokens_cur = seq_best.tokens;
|
|
|
|
const auto seek_delta = best_decoder.seek_delta;
|
|
|
|
|
|
|
|
const auto result_len = best_decoder.sequence.result_len;
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens
|
|
|
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
|
|
|
//prompt_past.clear();
|
|
|
|
|
|
|
|
//prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// store the text from this iteration
|
|
|
|
// update prompt_past
|
|
|
|
if (!tokens_cur.empty()) {
|
|
|
|
prompt_past.clear();
|
|
|
|
int i0 = 0;
|
|
|
|
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string text;
|
|
|
|
for (int i = 0; i < result_len; ++i) {
|
|
|
|
|
|
|
|
prompt_past.push_back(tokens_cur[i].id);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) tokens_cur.size(); i++) {
|
|
|
|
// store the text from this iteration
|
|
|
|
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
|
|
|
|
if (!tokens_cur.empty()) {
|
|
|
|
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
|
|
|
|
int i0 = 0;
|
|
|
|
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
|
|
|
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
|
|
|
|
std::string text;
|
|
|
|
} else {
|
|
|
|
|
|
|
|
text += whisper_token_to_str(ctx, tokens_cur[i].id);
|
|
|
|
for (int i = 0; i < (int) tokens_cur.size(); i++) {
|
|
|
|
}
|
|
|
|
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
|
|
|
|
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
|
|
|
|
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
|
|
|
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
|
|
|
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
|
|
|
|
if (!text.empty()) {
|
|
|
|
|
|
|
|
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
|
|
|
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
|
|
|
|
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
|
|
|
} else {
|
|
|
|
|
|
|
|
text += whisper_token_to_str(ctx, tokens_cur[i].id);
|
|
|
|
if (params.print_realtime) {
|
|
|
|
}
|
|
|
|
if (params.print_timestamps) {
|
|
|
|
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
|
|
|
|
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
|
|
|
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
|
|
|
} else {
|
|
|
|
if (!text.empty()) {
|
|
|
|
printf("%s", text.c_str());
|
|
|
|
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
|
|
|
fflush(stdout);
|
|
|
|
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (params.print_realtime) {
|
|
|
|
|
|
|
|
if (params.print_timestamps) {
|
|
|
|
|
|
|
|
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
printf("%s", text.c_str());
|
|
|
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result_all.push_back({ tt0, tt1, text, {} });
|
|
|
|
result_all.push_back({ tt0, tt1, text, {} });
|
|
|
|
for (int j = i0; j <= i; j++) {
|
|
|
|
for (int j = i0; j <= i; j++) {
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int n_new = 1;
|
|
|
|
int n_new = 1;
|
|
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
text = "";
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
|
|
|
|
|
|
|
i++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
i--;
|
|
|
|
|
|
|
|
t0 = t1;
|
|
|
|
|
|
|
|
i0 = i + 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
text = "";
|
|
|
|
|
|
|
|
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
|
|
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
i--;
|
|
|
|
|
|
|
|
t0 = t1;
|
|
|
|
|
|
|
|
i0 = i + 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!text.empty()) {
|
|
|
|
if (!text.empty()) {
|
|
|
|
const auto t1 = seek + seek_delta;
|
|
|
|
const auto t1 = seek + seek_delta;
|
|
|
|
|
|
|
|
|
|
|
|
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
|
|
|
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
|
|
|
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
|
|
|
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
|
|
|
|
|
|
|
|
|
|
|
if (params.print_realtime) {
|
|
|
|
if (params.print_realtime) {
|
|
|
|
if (params.print_timestamps) {
|
|
|
|
if (params.print_timestamps) {
|
|
|
|
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
|
|
|
|
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
printf("%s", text.c_str());
|
|
|
|
printf("%s", text.c_str());
|
|
|
|
fflush(stdout);
|
|
|
|
fflush(stdout);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result_all.push_back({ tt0, tt1, text, {} });
|
|
|
|
result_all.push_back({ tt0, tt1, text, {} });
|
|
|
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
|
|
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int n_new = 1;
|
|
|
|
int n_new = 1;
|
|
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seek += seek_delta;
|
|
|
|
// update audio window
|
|
|
|
|
|
|
|
seek += seek_delta;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
return 0;
|
|
|
@ -3704,6 +3689,7 @@ int whisper_full_parallel(
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
|