@@ -472,13 +472,12 @@ struct whisper_context {
     std::vector<uint8_t> buf_compute_layer;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
-    std::vector<float> probs;
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token> prompt_past;
 
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+    std::vector<std::pair<double, whisper_vocab::id>> logits_id;
 
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
@@ -575,6 +574,10 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     fprintf(stderr, "%s: loading model\n", __func__);
 
+    const int64_t t_start_us = ggml_time_us();
+
+    wctx.t_start_us = t_start_us;
+
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
 
@@ -770,9 +773,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     }
 
     wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-    wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-    wctx.probs_id.reserve(n_vocab);
+    wctx.logits_id.reserve(n_vocab);
 
     // TAGS: WHISPER_DECODER_INIT
     wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
@@ -1178,6 +1180,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.t_load_us = ggml_time_us() - t_start_us;
+
     return true;
 }
 
@@ -1192,8 +1196,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 //
 static bool whisper_encode(
         whisper_context & wctx,
-        const int n_threads,
-        const int mel_offset) {
+        const int mel_offset,
+        const int n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
     const auto & model = wctx.model;
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
@@ -1585,6 +1591,8 @@ static bool whisper_encode(
 
     ggml_free(ctx0);
 
+    wctx.t_encode_us += ggml_time_us() - t_start_us;
+
     return true;
 }
@@ -1601,10 +1609,12 @@ static bool whisper_encode(
 static bool whisper_decode(
         whisper_context & wctx,
         whisper_decoder & decoder,
-        const int n_threads,
         const whisper_token * tokens,
         const int n_tokens,
-        const int n_past) {
+        const int n_past,
+        const int n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1613,7 +1623,6 @@ static bool whisper_decode(
     WHISPER_ASSERT(!!kv_self.ctx);
 
     auto & logits_out = wctx.logits;
-    auto & probs_out = wctx.probs;
 
     const int n_vocab = hparams.n_vocab;
 
@@ -1625,6 +1634,8 @@ static bool whisper_decode(
     const int N = n_tokens;
     const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
+    //fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx);
+
     struct ggml_init_params params;
     params.mem_size = wctx.buf_compute.size();
     params.mem_buffer = wctx.buf_compute.data();
@@ -1933,25 +1944,18 @@ static bool whisper_decode(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    // logits -> probs
-    cur = ggml_dup(ctx0, logits);
-    cur = ggml_soft_max(ctx0, cur); // in-place
-
     // run the computation
     {
         struct ggml_cgraph gf = {};
         gf.n_threads = n_threads;
 
-        ggml_build_forward_expand(&gf, cur);
+        ggml_build_forward_expand(&gf, logits);
         ggml_graph_compute (ctx0, &gf);
     }
 
     logits_out.resize(N*n_vocab);
     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
 
-    probs_out.resize(N*n_vocab);
-    memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
-
     if (N > 1) {
         //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
         //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
@@ -1960,6 +1964,8 @@ static bool whisper_decode(
 
     ggml_free(ctx0);
 
+    wctx.t_decode_us += ggml_time_us() - t_start_us;
+
     return true;
 }
@@ -2062,6 +2068,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
+        whisper_context & wctx,
         const float * samples,
         const int n_samples,
         const int /*sample_rate*/,
@@ -2072,6 +2079,7 @@ static bool log_mel_spectrogram(
         const whisper_filters & filters,
         const bool speed_up,
         whisper_mel & mel) {
+    const int64_t t_start_us = ggml_time_us();
 
     // Hanning window
     std::vector<float> hann;
@@ -2180,6 +2188,8 @@ static bool log_mel_spectrogram(
         mel.data[i] = (mel.data[i] + 4.0)/4.0;
     }
 
+    wctx.t_mel_us += ggml_time_us() - t_start_us;
+
     return true;
 }
@@ -2324,10 +2334,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
 
     whisper_context * ctx = new whisper_context;
 
-    const int64_t t_start_us = ggml_time_us();
-
-    ctx->t_start_us = t_start_us;
-
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
         fprintf(stderr, "%s: failed to load model\n", __func__);
@@ -2335,8 +2341,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
         return nullptr;
     }
 
-    ctx->t_load_us = ggml_time_us() - t_start_us;
-
     loader->close(loader->context);
 
     return ctx;
@@ -2363,29 +2367,21 @@ void whisper_free(struct whisper_context * ctx) {
 }
 
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
-    ctx->t_mel_us = ggml_time_us() - t_start_us;
-
     return 0;
 }
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
-    ctx->t_mel_us = ggml_time_us() - t_start_us;
-
     return 0;
 }
@@ -2409,31 +2405,23 @@ int whisper_set_mel(
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!whisper_encode(*ctx, n_threads, offset)) {
+    if (!whisper_encode(*ctx, offset, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
 
-    ctx->t_encode_us += ggml_time_us() - t_start_us;
-
     return 0;
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
     // TODO: add selected_decoder_id to context
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
 
-    ctx->t_decode_us += ggml_time_us() - t_start_us;
-
     return 0;
 }
@@ -2517,36 +2505,39 @@ int whisper_lang_auto_detect(
         return -7;
     }
 
-    auto & probs_id = ctx->probs_id;
-    probs_id.clear();
+    auto & logits_id = ctx->logits_id;
+    logits_id.clear();
 
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
     }
 
     // sort descending
     {
-        using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;
-        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
             return a.first > b.first;
         });
     }
 
     // softmax
     {
-        float sum = 0;
-        for (const auto & kv : probs_id) {
-            sum += exp(kv.first);
+        const auto max = logits_id[0].first;
+
+        double sum = 0.0f;
+        for (auto & kv : logits_id) {
+            kv.first = exp(kv.first - max);
+            sum += kv.first;
         }
 
-        for (auto & kv : probs_id) {
-            kv.first = exp(kv.first) / sum;
+        for (auto & kv : logits_id) {
+            kv.first /= sum;
         }
     }
 
     {
-        for (const auto & prob : probs_id) {
+        for (const auto & prob : logits_id) {
             if (lang_probs) {
                 lang_probs[prob.second] = prob.first;
             }
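Note (not part of the patch): below is a minimal, self-contained C++ sketch of the max-subtracted softmax that the hunk above now applies to the per-language logits. The language ids and logit values are made up for illustration; the whisper.cpp code operates on ctx->logits_id instead of a local vector.

// illustrative sketch only, with hypothetical values
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // (logit, lang_id) pairs, e.g. the raw scores of a few language tokens
    std::vector<std::pair<double, int>> logits_id = {
        {7.5, 0}, {3.2, 1}, {1.1, 2},
    };

    // sort descending so the largest logit comes first
    std::sort(logits_id.begin(), logits_id.end(),
              [](const auto & a, const auto & b) { return a.first > b.first; });

    // subtract the maximum before exponentiating to avoid overflow,
    // then normalize so the values sum to 1
    const double max = logits_id[0].first;
    double sum = 0.0;
    for (auto & kv : logits_id) {
        kv.first = std::exp(kv.first - max);
        sum += kv.first;
    }
    for (auto & kv : logits_id) {
        kv.first /= sum;
    }

    for (const auto & kv : logits_id) {
        printf("lang %d : p = %f\n", kv.second, kv.first);
    }
    return 0;
}

Subtracting the largest logit before exp() keeps the exponentials bounded while leaving the normalized probabilities unchanged, which is why the patch can work directly on logits instead of a precomputed probability array.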
@@ -2555,7 +2546,7 @@ int whisper_lang_auto_detect(
         }
     }
 
-    return probs_id[0].second;
+    return logits_id[0].second;
 }
 
 int whisper_n_len(struct whisper_context * ctx) {
@@ -2582,10 +2573,6 @@ float * whisper_get_logits(struct whisper_context * ctx) {
     return ctx->logits.data();
 }
 
-float * whisper_get_probs(struct whisper_context * ctx) {
-    return ctx->probs.data();
-}
-
 const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
     return ctx->vocab.id_to_token.at(token).c_str();
 }
@@ -2878,13 +2865,11 @@ static void whisper_process_logits(
     auto & logits = decoder.logits;
     auto & logprobs = decoder.logprobs;
 
     {
-        probs.resize(n_logits);
-        memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float));
-
         logits.resize(n_logits);
         memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
 
         // will be populated a bit later
+        probs.resize(n_logits);
         logprobs.resize(n_logits);
     }
@@ -2904,6 +2889,10 @@ static void whisper_process_logits(
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
 
+        // suppress sot and solm tokens
+        logits[vocab.token_sot] = -INFINITY;
+        logits[vocab.token_solm] = -INFINITY;
+
         // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
         {
@@ -2941,11 +2930,18 @@ static void whisper_process_logits(
         const float logit_max = *std::max_element(logits.begin(), logits.end());
         float logsumexp = 0.0f;
         for (int i = 0; i < n_logits; ++i) {
-            logsumexp += expf(logits[i] - logit_max);
+            if (logits[i] > -INFINITY) {
+                logsumexp += expf(logits[i] - logit_max);
+            }
         }
         logsumexp = logf(logsumexp) + logit_max;
 
         for (int i = 0; i < n_logits; ++i) {
-            logprobs[i] = logits[i] - logsumexp;
+            if (logits[i] > -INFINITY) {
+                logprobs[i] = logits[i] - logsumexp;
+            } else {
+                logprobs[i] = -INFINITY;
+            }
         }
     }
|
@ -2958,19 +2954,36 @@ static void whisper_process_logits(
|
|
|
|
float logsumexp = 0.0f;
|
|
|
|
float logsumexp = 0.0f;
|
|
|
|
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
|
|
|
|
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
|
|
|
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
|
|
|
|
|
|
|
if (logprobs[i] > -INFINITY) {
|
|
|
|
logsumexp += expf(logprobs[i] - logprob_max);
|
|
|
|
logsumexp += expf(logprobs[i] - logprob_max);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
logsumexp = logf(logsumexp) + logprob_max;
|
|
|
|
}
|
|
|
|
timestamp_logprob = logsumexp;
|
|
|
|
if (logsumexp > 0.0f) {
|
|
|
|
|
|
|
|
timestamp_logprob = logf(logsumexp) + logprob_max;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
|
|
|
|
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
|
|
|
|
|
|
|
|
|
|
|
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
|
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
|
|
for (int i = 0; i < vocab.token_beg; ++i) {
|
|
|
|
for (int i = 0; i < vocab.token_beg; ++i) {
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
logits[i] = -INFINITY;
|
|
|
|
|
|
|
|
logprobs[i] = -INFINITY;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// compute probs
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
|
|
|
|
if (logits[i] == -INFINITY) {
|
|
|
|
|
|
|
|
probs[i] = 0.0f;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
probs[i] = expf(logprobs[i]);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
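Note (not part of the patch): a small standalone sketch, with made-up toy values, of the log-softmax that the whisper_process_logits hunks above compute: logprob[i] = logit[i] - logsumexp(logits), with the maximum factored out for numerical stability and masked (-INFINITY) entries skipped so they end up with probability 0.

// illustrative sketch only, with hypothetical values
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = { 2.0f, 0.5f, -INFINITY, 1.0f }; // one entry masked out
    std::vector<float> logprobs(logits.size());

    const float logit_max = *std::max_element(logits.begin(), logits.end());

    // accumulate exp(logit - max), skipping masked entries
    float logsumexp = 0.0f;
    for (float l : logits) {
        if (l > -INFINITY) {
            logsumexp += expf(l - logit_max);
        }
    }
    logsumexp = logf(logsumexp) + logit_max;

    // log-probability of each token; exp(logprob) recovers the probability
    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = (logits[i] > -INFINITY) ? logits[i] - logsumexp : -INFINITY;
    }

    for (size_t i = 0; i < logits.size(); ++i) {
        printf("i=%zu logit=%8.3f logprob=%8.3f prob=%8.3f\n",
               i, logits[i], logprobs[i], expf(logprobs[i]));
    }
    return 0;
}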
@@ -2978,8 +2991,10 @@ static void whisper_process_logits(
     // print first 100 logits - token string : logit
     for (int i = 0; i < 100; i++) {
         const auto token = vocab.id_to_token.at(i);
+        const auto prob = probs[i];
         const auto logit = logits[i];
-        printf("%s : %f\n", token.c_str(), logit);
+        const auto logprob = logprobs[i];
+        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
     }
 
     // "And", "and", " And", " and"
@@ -2994,26 +3009,13 @@ static void whisper_process_logits(
     printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
     printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
     printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-#endif
 
-    //switch (params.strategy) {
-    //    case WHISPER_SAMPLING_GREEDY:
-    //        {
-    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
-    //            // TODO: implement
-    //            result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
-    //        } break;
-    //    case WHISPER_SAMPLING_BEAM_SEARCH:
-    //        {
-    //            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
-    //            // TODO: implement
-    //        } break;
-    //}
-
-    //sum_logprobs += logprobs[result.id];
-    //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
-
-    //return result;
+    printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+    printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+#endif
 }
 
 // select the most probable token
@@ -3052,13 +3054,6 @@ static whisper_token_data whisper_sample_best(
     }
 
     for (int i = 0; i < n_logits; ++i) {
-        // never sample these:
-        if (i == vocab.token_sot ||
-            i == vocab.token_solm ||
-            i == vocab.token_not) {
-            continue;
-        }
-
         if (result.p < probs[i]) {
             result.id = i;
             result.p = probs[i];
@@ -3334,18 +3329,23 @@ int whisper_full(
 
         // print the prompt
         //printf("\n\n");
-        //for (int i = 0; i < prompt.size(); i++) {
+        //for (int i = 0; i < (int) prompt.size(); i++) {
         //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
         //}
        //printf("\n\n");
 
-        if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) {
+        if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
             fprintf(stderr, "%s: failed to decode\n", __func__);
             return -7;
         }
 
+        {
+            const int64_t t_start_sample_us = ggml_time_us();
+
             whisper_process_logits(*ctx, ctx->decoders[0], params);
 
+            ctx->decoders[0].n_past += prompt.size();
+
             for (int j = 1; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
 
@@ -3357,9 +3357,14 @@ int whisper_full(
                 memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
                 memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
             }
+
+            ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        }
 
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+            const int64_t t_start_sample_us = ggml_time_us();
+
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
@@ -3407,7 +3412,8 @@ int whisper_full(
                 {
                     const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                    printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+                    printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                            __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
                 }
 
                 // end of segment
@@ -3468,6 +3474,8 @@ int whisper_full(
                 }
             }
 
+            ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];
@@ -3480,14 +3488,20 @@ int whisper_full(
 
                 //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
 
-                if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) {
+                if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
                     fprintf(stderr, "%s: failed to decode\n", __func__);
                     return -7;
                 }
 
+                {
+                    const int64_t t_start_sample_us = ggml_time_us();
+
                     whisper_process_logits(*ctx, decoder, params);
 
                     ++decoder.n_past;
+
+                    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                }
             }
         }
@@ -3502,12 +3516,15 @@ int whisper_full(
                     continue;
                 }
 
-                whisper_sequence_score(params, ctx->decoders[j].sequence);
+                decoder.sequence.tokens.resize(decoder.sequence.result_len);
+
+                whisper_sequence_score(params, decoder.sequence);
 
                 if (best_score < decoder.sequence.score) {
                     best_score = decoder.sequence.score;
                     best_decoder_id = j;
                 }
+
+                fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
             }
         }
@@ -3671,9 +3688,8 @@ int whisper_full_parallel(
         ctx_p = *ctx;
 
         ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
-        ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
 
-        ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
+        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
 
         if (!kv_cache_reinit(ctx_p.kv_cross)) {
             fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);