|
|
|
@ -416,11 +416,16 @@ struct whisper_model {
|
|
|
|
|
std::map<std::string, struct ggml_tensor *> tensors;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct whisper_sequence {
|
|
|
|
|
std::vector<whisper_token_data> tokens;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct whisper_decoder {
|
|
|
|
|
whisper_kv_cache kv_self;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct whisper_sequence {
|
|
|
|
|
whisper_sequence sequence;
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> prompt;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct whisper_context {
|
|
|
|
@ -759,6 +764,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
wctx.work_logprobs.reserve(vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
wctx.probs_id.reserve(n_vocab);
|
|
|
|
|
|
|
|
|
|
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
|
|
|
wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t ctx_size = 0;
|
|
|
|
@ -2797,15 +2805,23 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0,
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
|
|
/*.logprob_threshold =*/ -1.0f,
|
|
|
|
|
/*.no_speech_threshold =*/ 0.6f,
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.dummy =*/ 0,
|
|
|
|
|
/*.best_of =*/ 5,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.beam_search =*/ {
|
|
|
|
|
/*.beam_width =*/ -1,
|
|
|
|
|
/*.n_best =*/ -1,
|
|
|
|
|
/*.beam_size =*/ -1,
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
@ -2849,15 +2865,23 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
|
|
|
|
|
/*.suppress_blank =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0,
|
|
|
|
|
/*.temperature =*/ 0.0f,
|
|
|
|
|
/*.max_initial_timestamp =*/ 1.0f,
|
|
|
|
|
|
|
|
|
|
/*.temperature_increment =*/ 0.2f,
|
|
|
|
|
/*.compression_ratio_threshold =*/ 2.4f,
|
|
|
|
|
/*.logprob_threshold =*/ -1.0f,
|
|
|
|
|
/*.no_speech_threshold =*/ 0.6f,
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
|
/*.dummy =*/ 0,
|
|
|
|
|
/*.best_of =*/ 5,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.beam_search =*/ {
|
|
|
|
|
/*.beam_width =*/ 10,
|
|
|
|
|
/*.n_best =*/ 5,
|
|
|
|
|
/*.beam_size =*/ 5,
|
|
|
|
|
|
|
|
|
|
/*.patience =*/ -1.0f,
|
|
|
|
|
/*.length_penalty =*/ -1.0f,
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
/*.new_segment_callback =*/ nullptr,
|
|
|
|
@ -3142,6 +3166,45 @@ int whisper_full(
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a set of temperatures to use
|
|
|
|
|
// [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
|
|
|
|
|
std::vector<float> temperatures;
|
|
|
|
|
if (params.temperature_increment > 0.0f) {
|
|
|
|
|
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) {
|
|
|
|
|
temperatures.push_back(t);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
temperatures.push_back(params.temperature);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// initialize the decoders
|
|
|
|
|
int n_decoders = 1;
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
{
|
|
|
|
|
n_decoders = params.greedy.best_of;
|
|
|
|
|
} break;
|
|
|
|
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
{
|
|
|
|
|
n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
|
|
|
|
|
} break;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < n_decoders; i++) {
|
|
|
|
|
if (ctx->decoders[i].kv_self.ctx == nullptr) {
|
|
|
|
|
ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
|
|
|
|
|
if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i);
|
|
|
|
|
return -4;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
|
ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// the accumulated text context so far
|
|
|
|
|
auto & prompt_past = ctx->prompt_past;
|
|
|
|
|
if (params.no_context) {
|
|
|
|
@ -3160,7 +3223,7 @@ int whisper_full(
|
|
|
|
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
|
|
|
|
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
|
|
|
|
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
|
|
|
|
return -4;
|
|
|
|
|
return -5;
|
|
|
|
|
}
|
|
|
|
|
ctx->exp_n_audio_ctx = params.audio_ctx;
|
|
|
|
|
|
|
|
|
@ -3201,12 +3264,6 @@ int whisper_full(
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if there is a very short audio segment left to process, we remove any past prompt since it tends
|
|
|
|
|
// to confuse the decoder and often make it repeat or hallucinate stuff
|
|
|
|
|
if (seek > seek_start && seek + 500 >= seek_end) {
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params.encoder_begin_callback) {
|
|
|
|
|
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
|
|
|
|
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
|
|
@ -3217,7 +3274,13 @@ int whisper_full(
|
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
|
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
|
return -4;
|
|
|
|
|
return -6;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if there is a very short audio segment left to process, we remove any past prompt since it tends
|
|
|
|
|
// to confuse the decoder and often make it repeat or hallucinate stuff
|
|
|
|
|
if (seek > seek_start && seek + 500 >= seek_end) {
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
prompt.clear();
|
|
|
|
@ -3257,7 +3320,7 @@ int whisper_full(
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
return -5;
|
|
|
|
|
return -7;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
n_past += prompt.size();
|
|
|
|
@ -3469,16 +3532,27 @@ int whisper_full_parallel(
|
|
|
|
|
|
|
|
|
|
ctx_p = *ctx;
|
|
|
|
|
|
|
|
|
|
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
|
ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
if (!kv_cache_reinit(ctx_p.kv_cross)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__);
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
|
|
|
|
|
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__);
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|