whisper : wip decoding parameters + strategies

pull/291/head
Georgi Gerganov 3 years ago
parent 628843c60a
commit 9551d7fabd
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -416,11 +416,16 @@ struct whisper_model {
std::map<std::string, struct ggml_tensor *> tensors;
};
// A sequence of token data entries.
// Elsewhere in this file it is accessed as the `sequence` member of a decoder
// (e.g. `decoders[i].sequence.tokens`).
struct whisper_sequence {
// the token data entries that make up the sequence, in order of insertion
std::vector<whisper_token_data> tokens;
};
struct whisper_decoder {
whisper_kv_cache kv_self;
};
struct whisper_sequence {
whisper_sequence sequence;
std::vector<whisper_token> prompt;
};
struct whisper_context {
@ -759,6 +764,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
wctx.work_logprobs.reserve(vocab.n_vocab);
wctx.probs_id.reserve(n_vocab);
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
}
size_t ctx_size = 0;
@ -2797,15 +2805,23 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true,
/*.max_initial_timestamp =*/ 1.0,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.temperature_increment =*/ 0.2f,
/*.compression_ratio_threshold =*/ 2.4f,
/*.logprob_threshold =*/ -1.0f,
/*.no_speech_threshold =*/ 0.6f,
/*.greedy =*/ {
/*.dummy =*/ 0,
/*.best_of =*/ 5,
},
/*.beam_search =*/ {
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
/*.beam_size =*/ -1,
/*.patience =*/ -1.0f,
/*.length_penalty =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
@ -2849,15 +2865,23 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.suppress_blank =*/ true,
/*.max_initial_timestamp =*/ 1.0,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.temperature_increment =*/ 0.2f,
/*.compression_ratio_threshold =*/ 2.4f,
/*.logprob_threshold =*/ -1.0f,
/*.no_speech_threshold =*/ 0.6f,
/*.greedy =*/ {
/*.dummy =*/ 0,
/*.best_of =*/ 5,
},
/*.beam_search =*/ {
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
/*.beam_size =*/ 5,
/*.patience =*/ -1.0f,
/*.length_penalty =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
@ -3142,6 +3166,45 @@ int whisper_full(
return 0;
}
// a set of temperatures to use
// [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
// where t0 = params.temperature and delta = params.temperature_increment
std::vector<float> temperatures;
if (params.temperature_increment > 0.0f) {
// the 1e-6f slack keeps a last step that lands on ~1.0f in the list even if the
// repeated float additions overshoot 1.0f slightly
// NOTE: if params.temperature itself is above that bound, the list stays empty
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) {
temperatures.push_back(t);
}
} else {
// non-positive increment: no stepping, only the base temperature is used
temperatures.push_back(params.temperature);
}
// initialize the decoders
int n_decoders = 1;
switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
{
n_decoders = params.greedy.best_of;
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
} break;
};
for (int i = 1; i < n_decoders; i++) {
if (ctx->decoders[i].kv_self.ctx == nullptr) {
ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i);
return -4;
}
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity());
}
}
// the accumulated text context so far
auto & prompt_past = ctx->prompt_past;
if (params.no_context) {
@ -3160,7 +3223,7 @@ int whisper_full(
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -4;
return -5;
}
ctx->exp_n_audio_ctx = params.audio_ctx;
@ -3201,12 +3264,6 @@ int whisper_full(
break;
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
if (seek > seek_start && seek + 500 >= seek_end) {
prompt_past.clear();
}
if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@ -3217,7 +3274,13 @@ int whisper_full(
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return -4;
return -6;
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
if (seek > seek_start && seek + 500 >= seek_end) {
prompt_past.clear();
}
prompt.clear();
@ -3257,7 +3320,7 @@ int whisper_full(
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -5;
return -7;
}
n_past += prompt.size();
@ -3469,16 +3532,27 @@ int whisper_full_parallel(
ctx_p = *ctx;
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab);
ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab);
ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
if (!kv_cache_reinit(ctx_p.kv_cross)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__);
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
return false;
}
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__);
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
return false;
}
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx);
}
}

@ -278,15 +278,24 @@ extern "C" {
// common decoding parameters:
bool suppress_blank;
float temperature;
float max_initial_timestamp;
// fallback parameters
float temperature_increment;
float compression_ratio_threshold;
float logprob_threshold;
float no_speech_threshold;
struct {
int dummy;
int best_of;
} greedy;
struct {
int beam_width;
int n_best;
int beam_size;
float patience;
float length_penalty;
} beam_search;
whisper_new_segment_callback new_segment_callback;

Loading…
Cancel
Save