whisper : minor fixes

pull/291/head
Georgi Gerganov 3 years ago
parent c301a7942b
commit 5e97f80fc5
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -459,7 +459,7 @@ int main(int argc, char ** argv) {
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens;
@ -615,6 +615,10 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// disable best_of fallback
wparams.temperature_increment = -1.0f;
wparams.greedy.best_of = -1;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

@ -25,7 +25,7 @@
} \
} while (0)
#define WHISPER_DEBUG
//#define WHISPER_DEBUG
#if defined(WHISPER_DEBUG)
#define WHISPER_PRINT_DEBUG(...) \
do { \
@ -3216,6 +3216,7 @@ int whisper_full(
// initialize the decoders
int n_decoders = 1;
switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
{
@ -3227,6 +3228,8 @@ int whisper_full(
} break;
};
n_decoders = std::max(1, n_decoders);
// TAGS: WHISPER_DECODER_INIT
for (int j = 1; j < n_decoders; j++) {
auto & decoder = ctx->decoders[j];
@ -3348,13 +3351,9 @@ int whisper_full(
} break;
};
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
if (t_cur > 0.5) {
prompt_past.clear();
n_decoders_cur = std::max(1, n_decoders_cur);
WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
}
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
// TAGS: WHISPER_DECODER_INIT
for (int j = 0; j < n_decoders_cur; ++j) {
@ -3381,7 +3380,7 @@ int whisper_full(
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
if (!prompt_past.empty()) {
if (!prompt_past.empty() && t_cur > 0.5f) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
@ -3392,11 +3391,11 @@ int whisper_full(
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
// print the prompt
//printf("\n\n");
//WHISPER_PRINT_DEBUG("\n\n");
//for (int i = 0; i < (int) prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
// WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
//}
//printf("\n\n");
//WHISPER_PRINT_DEBUG("\n\n");
if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
@ -3608,7 +3607,6 @@ int whisper_full(
bool success = true;
// implement logprob threshold
// TODO: implement compression threshold
{
auto & decoder = ctx->decoders[best_decoder_id];
@ -3646,7 +3644,7 @@ int whisper_full(
}
// store the text from this iteration
if (!tokens_cur.empty()) {
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));

Loading…
Cancel
Save