whisper : do not provide past prompt when n_max_text_ctx == 0

pull/451/head
Georgi Gerganov 2 years ago
parent 60337f5306
commit b992f3709e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -3524,7 +3524,7 @@ int whisper_full(
prompt.clear(); prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation // if we have already generated some text, use it as a prompt to condition the next generation
if (!prompt_past.empty() && t_cur < 0.5f) { if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) }; prompt = { whisper_token_prev(ctx) };
@ -3535,11 +3535,11 @@ int whisper_full(
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
// print the prompt // print the prompt
//WHISPER_PRINT_DEBUG("\n\n"); WHISPER_PRINT_DEBUG("\n\n");
//for (int i = 0; i < (int) prompt.size(); i++) { for (int i = 0; i < (int) prompt.size(); i++) {
// WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
//} }
//WHISPER_PRINT_DEBUG("\n\n"); WHISPER_PRINT_DEBUG("\n\n");
if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);

Loading…
Cancel
Save