|
|
@ -25,6 +25,16 @@
|
|
|
|
} \
|
|
|
|
} \
|
|
|
|
} while (0)
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Debug tracing switch.
//
// WHISPER_DEBUG is defined here unless the build already defined it (for
// example on the compiler command line); the guard avoids an incompatible
// macro redefinition in that case.
#if !defined(WHISPER_DEBUG)
#define WHISPER_DEBUG
#endif

// WHISPER_PRINT_DEBUG(fmt, ...) - printf-style trace output.
//
// When WHISPER_DEBUG is defined, the arguments are forwarded unchanged to
// fprintf(stderr, ...). Otherwise the macro expands to an empty statement and
// its arguments are never evaluated.
//
// Both variants expand to a single do { ... } while (0) statement, so the
// macro behaves the same way in unbraced if/else bodies and always takes a
// trailing semicolon.
#if defined(WHISPER_DEBUG)
#define WHISPER_PRINT_DEBUG(...) \
    do { \
        fprintf(stderr, __VA_ARGS__); \
    } while (0)
#else
#define WHISPER_PRINT_DEBUG(...) \
    do { \
    } while (0)
#endif
|
|
|
|
|
|
|
|
|
|
|
|
#define WHISPER_USE_FLASH_ATTN
|
|
|
|
#define WHISPER_USE_FLASH_ATTN
|
|
|
|
//#define WHISPER_USE_FLASH_FF
|
|
|
|
//#define WHISPER_USE_FLASH_FF
|
|
|
|
#define WHISPER_MAX_DECODERS 16
|
|
|
|
#define WHISPER_MAX_DECODERS 16
|
|
|
@ -1640,7 +1650,7 @@ static bool whisper_decode(
|
|
|
|
const int N = n_tokens;
|
|
|
|
const int N = n_tokens;
|
|
|
|
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
|
|
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx);
|
|
|
|
WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params;
|
|
|
|
struct ggml_init_params params;
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
@ -3228,7 +3238,7 @@ int whisper_full(
|
|
|
|
return -4;
|
|
|
|
return -4;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
|
|
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
|
|
|
|
|
|
|
|
|
|
decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
|
|
|
|
|
|
|
@ -3338,12 +3348,12 @@ int whisper_full(
|
|
|
|
} break;
|
|
|
|
} break;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
|
|
|
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
|
|
|
|
|
|
|
|
|
|
|
if (t_cur > 0.5) {
|
|
|
|
if (t_cur > 0.5) {
|
|
|
|
prompt_past.clear();
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: clearing prompt_past\n", __func__);
|
|
|
|
WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
@ -3465,11 +3475,13 @@ int whisper_full(
|
|
|
|
has_ts = true;
|
|
|
|
has_ts = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//{
|
|
|
|
#ifdef WHISPER_DEBUG
|
|
|
|
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
|
|
{
|
|
|
|
// printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
|
|
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
|
|
// __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
|
|
|
//}
|
|
|
|
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
// end of segment
|
|
|
|
// end of segment
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
@ -3541,7 +3553,7 @@ int whisper_full(
|
|
|
|
decoder.tokens_tmp.resize(1);
|
|
|
|
decoder.tokens_tmp.resize(1);
|
|
|
|
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
|
|
|
|
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
|
|
|
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
|
|
|
|
|
|
|
|
|
|
|
|
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
|
|
|
|
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
@ -3574,11 +3586,11 @@ int whisper_full(
|
|
|
|
decoder.sequence.tokens.resize(decoder.sequence.result_len);
|
|
|
|
decoder.sequence.tokens.resize(decoder.sequence.result_len);
|
|
|
|
whisper_sequence_score(params, decoder.sequence);
|
|
|
|
whisper_sequence_score(params, decoder.sequence);
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
|
|
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
|
|
|
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
|
|
|
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
|
|
|
|
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
|
|
|
|
fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
|
|
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
|
|
|
__func__, j, decoder.sequence.entropy, params.entropy_threshold);
|
|
|
|
__func__, j, decoder.sequence.entropy, params.entropy_threshold);
|
|
|
|
|
|
|
|
|
|
|
|
decoder.failed = true;
|
|
|
|
decoder.failed = true;
|
|
|
@ -3606,15 +3618,15 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (success) {
|
|
|
|
if (success) {
|
|
|
|
fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
|
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
|
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
|
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
|
|
// fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
@ -3734,7 +3746,7 @@ int whisper_full(
|
|
|
|
// update audio window
|
|
|
|
// update audio window
|
|
|
|
seek += seek_delta;
|
|
|
|
seek += seek_delta;
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta);
|
|
|
|
WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|