whisper : minor fixes

pull/291/head
Georgi Gerganov 3 years ago
parent c301a7942b
commit 5e97f80fc5

@@ -459,7 +459,7 @@ int main(int argc, char ** argv) {
     struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
     std::vector<float> pcmf32    (n_samples_30s, 0.0f);
-    std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
+    std::vector<float> pcmf32_old;
     std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
     std::vector<whisper_token> prompt_tokens;
@@ -615,6 +615,10 @@ int main(int argc, char ** argv) {
         wparams.audio_ctx = params.audio_ctx;
         wparams.speed_up  = params.speed_up;
+        // disable best_of fallback
+        wparams.temperature_increment = -1.0f;
+        wparams.greedy.best_of        = -1;
         wparams.prompt_tokens   = params.no_context ? nullptr : prompt_tokens.data();
         wparams.prompt_n_tokens = params.no_context ? 0       : prompt_tokens.size();
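For context, a minimal sketch of how a caller might apply the same overrides when driving whisper_full directly. Field names follow this diff (temperature_increment, greedy.best_of); treat them as assumptions tied to this revision of whisper.h rather than a stable API:

// sketch: disable the temperature fallback for low-latency streaming use,
// mirroring the overrides added above (assumes whisper.h from this commit)
#include "whisper.h"

#include <vector>

int transcribe_chunk(struct whisper_context * ctx, const std::vector<float> & pcmf32) {
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.temperature_increment = -1.0f; // no retries at higher temperatures
    wparams.greedy.best_of        = -1;    // single greedy decoder only

    return whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
}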

@@ -25,7 +25,7 @@
     } \
 } while (0)
-#define WHISPER_DEBUG
+//#define WHISPER_DEBUG
 #if defined(WHISPER_DEBUG)
 #define WHISPER_PRINT_DEBUG(...) \
     do { \
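A sketch of the guarded debug-print pattern being toggled here, not the exact whisper.cpp definition: with WHISPER_DEBUG commented out, WHISPER_PRINT_DEBUG expands to nothing and the debug output disappears at compile time.

// illustrative stand-alone version of the pattern (assumed, simplified)
#include <cstdio>

//#define WHISPER_DEBUG

#if defined(WHISPER_DEBUG)
#define WHISPER_PRINT_DEBUG(...) \
    do { \
        fprintf(stderr, __VA_ARGS__); \
    } while (0)
#else
#define WHISPER_PRINT_DEBUG(...)
#endif

int main() {
    WHISPER_PRINT_DEBUG("%s: this only prints when WHISPER_DEBUG is defined\n", __func__);
    return 0;
}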
@@ -3216,6 +3216,7 @@ int whisper_full(
     // initialize the decoders
     int n_decoders = 1;
     switch (params.strategy) {
         case WHISPER_SAMPLING_GREEDY:
             {
@@ -3227,6 +3228,8 @@ int whisper_full(
             } break;
     };
+    n_decoders = std::max(1, n_decoders);
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = ctx->decoders[j];
@ -3348,13 +3351,9 @@ int whisper_full(
} break; } break;
}; };
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); n_decoders_cur = std::max(1, n_decoders_cur);
if (t_cur > 0.5) { WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
prompt_past.clear();
WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
}
// TAGS: WHISPER_DECODER_INIT // TAGS: WHISPER_DECODER_INIT
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
@@ -3381,7 +3380,7 @@ int whisper_full(
             prompt.clear();
             // if we have already generated some text, use it as a prompt to condition the next generation
-            if (!prompt_past.empty()) {
+            if (!prompt_past.empty() && t_cur > 0.5f) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
                 prompt = { whisper_token_prev(ctx) };
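For readers skimming the hunk, a simplified sketch of what this conditioning step amounts to (token type reduced to int, names hypothetical): keep at most n_max_text_ctx tokens, never more than half the text context, taken from the tail of the previously generated text.

#include <algorithm>
#include <vector>

// illustrative only; whisper.cpp builds the prompt in-place inside whisper_full
std::vector<int> build_prompt(const std::vector<int> & prompt_past,
                              int n_max_text_ctx, int n_text_ctx, int token_prev) {
    const int n_take = std::min(std::min(n_max_text_ctx, n_text_ctx/2), int(prompt_past.size()));

    std::vector<int> prompt = { token_prev };
    prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
    return prompt;
}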
@@ -3392,11 +3391,11 @@ int whisper_full(
             prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
             // print the prompt
-            //printf("\n\n");
+            //WHISPER_PRINT_DEBUG("\n\n");
             //for (int i = 0; i < (int) prompt.size(); i++) {
-            //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+            //    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
             //}
-            //printf("\n\n");
+            //WHISPER_PRINT_DEBUG("\n\n");
             if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
@@ -3608,7 +3607,6 @@ int whisper_full(
             bool success = true;
             // implement logprob threshold
-            // TODO: implement compression threshold
             {
                 auto & decoder = ctx->decoders[best_decoder_id];
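The block gated by this comment decides whether a decode "succeeded". A self-contained sketch of the kind of check involved (illustrative, not the whisper.cpp internals): if the mean per-token log-probability falls below a threshold, the result is rejected so the caller can retry at a higher temperature.

#include <numeric>
#include <vector>

// hypothetical helper; whisper.cpp keeps the running average on the decoder state
bool passes_logprob_threshold(const std::vector<float> & token_logprobs, float logprob_thold) {
    if (token_logprobs.empty()) {
        return true; // nothing to judge
    }
    const float sum = std::accumulate(token_logprobs.begin(), token_logprobs.end(), 0.0f);
    const float avg = sum / token_logprobs.size();
    return avg >= logprob_thold;
}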
@ -3646,7 +3644,7 @@ int whisper_full(
} }
// store the text from this iteration // store the text from this iteration
if (!tokens_cur.empty()) { if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
int i0 = 0; int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
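On the t0 computation: timestamp tokens step in 20 ms increments while whisper.cpp reports segment times in 10 ms ticks, hence the factor of 2 (this reading of the units is an assumption, not stated in the diff). A tiny conversion helper, hypothetical, for turning such a tick count into seconds:

#include <cstdint>
#include <cstdio>

// assumed unit: segment timestamps are counted in 10 ms ticks
double ticks_to_seconds(int64_t t) {
    return t / 100.0;
}

int main() {
    const int64_t t0 = 150; // e.g. a segment starting 1.5 s into the audio
    printf("segment starts at %.2f s\n", ticks_to_seconds(t0));
    return 0;
}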
