|
|
@ -418,14 +418,31 @@ struct whisper_model {
|
|
|
|
|
|
|
|
|
|
|
|
struct whisper_sequence {
|
|
|
|
struct whisper_sequence {
|
|
|
|
std::vector<whisper_token_data> tokens;
|
|
|
|
std::vector<whisper_token_data> tokens;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// the accumulated transcription in the current interation (used to truncate the tokens array)
|
|
|
|
|
|
|
|
int result_len;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double sum_logprobs;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
struct whisper_decoder {
|
|
|
|
struct whisper_decoder {
|
|
|
|
whisper_kv_cache kv_self;
|
|
|
|
whisper_kv_cache kv_self;
|
|
|
|
|
|
|
|
|
|
|
|
whisper_sequence sequence;
|
|
|
|
whisper_sequence sequence;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> prompt;
|
|
|
|
int n_past;
|
|
|
|
|
|
|
|
int seek_delta;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool failed; // has the current segment failed to decode?
|
|
|
|
|
|
|
|
bool completed; // has the decoder completed the current segment?
|
|
|
|
|
|
|
|
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> tokens;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
|
|
|
|
|
|
|
|
std::vector<float> logits;
|
|
|
|
|
|
|
|
std::vector<float> logprobs;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct whisper_context {
|
|
|
|
struct whisper_context {
|
|
|
@ -447,23 +464,17 @@ struct whisper_context {
|
|
|
|
|
|
|
|
|
|
|
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
|
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
|
|
|
|
|
|
|
|
|
|
int selected_decoder_id = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// memory buffers used by encode / decode contexts
|
|
|
|
// memory buffers used by encode / decode contexts
|
|
|
|
std::vector<uint8_t> buf_compute;
|
|
|
|
std::vector<uint8_t> buf_compute;
|
|
|
|
std::vector<uint8_t> buf_compute_layer;
|
|
|
|
std::vector<uint8_t> buf_compute_layer;
|
|
|
|
|
|
|
|
|
|
|
|
// decode output
|
|
|
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
|
|
std::vector<float> probs;
|
|
|
|
std::vector<float> probs;
|
|
|
|
std::vector<float> logits;
|
|
|
|
std::vector<float> logits;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_segment> result_all;
|
|
|
|
std::vector<whisper_segment> result_all;
|
|
|
|
std::vector<whisper_token> prompt_past;
|
|
|
|
std::vector<whisper_token> prompt_past;
|
|
|
|
|
|
|
|
|
|
|
|
// used to avoid allocations
|
|
|
|
|
|
|
|
std::vector<float> work_logits;
|
|
|
|
|
|
|
|
std::vector<float> work_logprobs;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
|
|
|
|
|
|
|
|
// [EXPERIMENTAL] token-level timestamps data
|
|
|
|
// [EXPERIMENTAL] token-level timestamps data
|
|
|
@ -651,8 +662,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
|
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
|
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
|
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
|
|
|
|
|
|
|
|
|
|
wctx.selected_decoder_id = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
|
|
|
|
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
|
|
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
@ -760,13 +769,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
wctx.work_logits.reserve (vocab.n_vocab);
|
|
|
|
|
|
|
|
wctx.work_logprobs.reserve(vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wctx.probs_id.reserve(n_vocab);
|
|
|
|
wctx.probs_id.reserve(n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
|
|
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
|
|
wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
|
|
|
|
|
|
|
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
size_t ctx_size = 0;
|
|
|
|
size_t ctx_size = 0;
|
|
|
@ -1587,6 +1596,7 @@ static bool whisper_encode(
|
|
|
|
//
|
|
|
|
//
|
|
|
|
static bool whisper_decode(
|
|
|
|
static bool whisper_decode(
|
|
|
|
whisper_context & wctx,
|
|
|
|
whisper_context & wctx,
|
|
|
|
|
|
|
|
whisper_decoder & decoder,
|
|
|
|
const int n_threads,
|
|
|
|
const int n_threads,
|
|
|
|
const whisper_token * tokens,
|
|
|
|
const whisper_token * tokens,
|
|
|
|
const int n_tokens,
|
|
|
|
const int n_tokens,
|
|
|
@ -1594,10 +1604,7 @@ static bool whisper_decode(
|
|
|
|
const auto & model = wctx.model;
|
|
|
|
const auto & model = wctx.model;
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
|
|
|
|
|
|
|
|
WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
|
|
|
|
auto & kv_self = decoder.kv_self;
|
|
|
|
WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
|
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
|
|
|
|
|
|
|
|
|
@ -2506,7 +2513,10 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
|
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
|
|
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
|
|
const int64_t t_start_us = ggml_time_us();
|
|
|
|
const int64_t t_start_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
|
|
|
|
// TODO: add selected_decoder_id to context
|
|
|
|
|
|
|
|
const int selected_decoder_id = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
|
|
|
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
|
|
return 1;
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2516,28 +2526,6 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
|
|
return 0;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: remove
|
|
|
|
|
|
|
|
static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
|
|
|
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: remove
|
|
|
|
|
|
|
|
static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
|
|
|
|
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
|
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
|
|
const auto res = tokenize(ctx->vocab, text);
|
|
|
|
const auto res = tokenize(ctx->vocab, text);
|
|
|
|
|
|
|
|
|
|
|
@ -2899,15 +2887,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
// forward declarations
|
|
|
|
// forward declarations
|
|
|
|
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
|
|
|
|
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
|
|
|
|
static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
struct whisper_context & ctx,
|
|
|
|
int i_segment,
|
|
|
|
int i_segment,
|
|
|
|
float thold_pt,
|
|
|
|
float thold_pt,
|
|
|
|
float thold_ptsum);
|
|
|
|
float thold_ptsum);
|
|
|
|
|
|
|
|
|
|
|
|
// wrap the last segment to max_len characters
|
|
|
|
// wrap the last segment to max_len characters
|
|
|
|
// returns the number of new segments
|
|
|
|
// returns the number of new segments
|
|
|
|
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
|
|
|
|
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
|
|
|
|
auto segment = ctx->result_all.back();
|
|
|
|
auto segment = ctx.result_all.back();
|
|
|
|
|
|
|
|
|
|
|
|
int res = 1;
|
|
|
|
int res = 1;
|
|
|
|
int acc = 0;
|
|
|
|
int acc = 0;
|
|
|
@ -2916,34 +2904,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) segment.tokens.size(); i++) {
|
|
|
|
for (int i = 0; i < (int) segment.tokens.size(); i++) {
|
|
|
|
const auto & token = segment.tokens[i];
|
|
|
|
const auto & token = segment.tokens[i];
|
|
|
|
if (token.id >= whisper_token_eot(ctx)) {
|
|
|
|
if (token.id >= whisper_token_eot(&ctx)) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const auto txt = whisper_token_to_str(ctx, token.id);
|
|
|
|
const auto txt = whisper_token_to_str(&ctx, token.id);
|
|
|
|
|
|
|
|
|
|
|
|
const int cur = strlen(txt);
|
|
|
|
const int cur = strlen(txt);
|
|
|
|
|
|
|
|
|
|
|
|
if (acc + cur > max_len && i > 0) {
|
|
|
|
if (acc + cur > max_len && i > 0) {
|
|
|
|
// split here
|
|
|
|
// split here
|
|
|
|
ctx->result_all.back().text = std::move(text);
|
|
|
|
ctx.result_all.back().text = std::move(text);
|
|
|
|
ctx->result_all.back().t1 = token.t0;
|
|
|
|
ctx.result_all.back().t1 = token.t0;
|
|
|
|
ctx->result_all.back().tokens.resize(i);
|
|
|
|
ctx.result_all.back().tokens.resize(i);
|
|
|
|
|
|
|
|
|
|
|
|
ctx->result_all.push_back({});
|
|
|
|
ctx.result_all.push_back({});
|
|
|
|
ctx->result_all.back().t0 = token.t0;
|
|
|
|
ctx.result_all.back().t0 = token.t0;
|
|
|
|
ctx->result_all.back().t1 = segment.t1;
|
|
|
|
ctx.result_all.back().t1 = segment.t1;
|
|
|
|
|
|
|
|
|
|
|
|
// add tokens [i, end] to the new segment
|
|
|
|
// add tokens [i, end] to the new segment
|
|
|
|
ctx->result_all.back().tokens.insert(
|
|
|
|
ctx.result_all.back().tokens.insert(
|
|
|
|
ctx->result_all.back().tokens.end(),
|
|
|
|
ctx.result_all.back().tokens.end(),
|
|
|
|
segment.tokens.begin() + i,
|
|
|
|
segment.tokens.begin() + i,
|
|
|
|
segment.tokens.end());
|
|
|
|
segment.tokens.end());
|
|
|
|
|
|
|
|
|
|
|
|
acc = 0;
|
|
|
|
acc = 0;
|
|
|
|
text = "";
|
|
|
|
text = "";
|
|
|
|
|
|
|
|
|
|
|
|
segment = ctx->result_all.back();
|
|
|
|
segment = ctx.result_all.back();
|
|
|
|
i = -1;
|
|
|
|
i = -1;
|
|
|
|
|
|
|
|
|
|
|
|
res++;
|
|
|
|
res++;
|
|
|
@ -2953,33 +2941,33 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx->result_all.back().text = std::move(text);
|
|
|
|
ctx.result_all.back().text = std::move(text);
|
|
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
// process the logits for the selected decoder
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
// - applyies logit filters
|
|
|
|
struct whisper_full_params params,
|
|
|
|
// - computes logprobs
|
|
|
|
double & sum_logprobs,
|
|
|
|
static void whisper_process_logits(
|
|
|
|
const std::vector<whisper_token> & prompt,
|
|
|
|
struct whisper_context & ctx,
|
|
|
|
const std::vector<whisper_token_data> & tokens_cur) {
|
|
|
|
struct whisper_decoder & decoder,
|
|
|
|
struct whisper_token_data result = {};
|
|
|
|
struct whisper_full_params params) {
|
|
|
|
|
|
|
|
const auto & vocab = ctx.vocab;
|
|
|
|
const auto & vocab = ctx->vocab;
|
|
|
|
const auto & tokens_cur = decoder.sequence.tokens;
|
|
|
|
|
|
|
|
|
|
|
|
const bool is_initial = tokens_cur.size() == 0;
|
|
|
|
const bool is_initial = tokens_cur.size() == 0;
|
|
|
|
const int n_logits = vocab.id_to_token.size();
|
|
|
|
const int n_logits = vocab.id_to_token.size();
|
|
|
|
|
|
|
|
|
|
|
|
WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
|
|
|
|
WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
// extract the logits for the last token
|
|
|
|
// extract the logits for the last token
|
|
|
|
// we will be mutating and therefore we don't want to use the ctx->logits buffer directly
|
|
|
|
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
|
|
|
|
auto & logits = ctx->work_logits;
|
|
|
|
auto & logits = decoder.logits;
|
|
|
|
auto & logprobs = ctx->work_logprobs;
|
|
|
|
auto & logprobs = decoder.logprobs;
|
|
|
|
{
|
|
|
|
{
|
|
|
|
logits.resize(n_logits);
|
|
|
|
logits.resize(n_logits);
|
|
|
|
memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
|
|
|
|
|
|
|
|
|
|
|
|
// will be populated a bit later
|
|
|
|
// will be populated a bit later
|
|
|
|
logprobs.resize(n_logits);
|
|
|
|
logprobs.resize(n_logits);
|
|
|
@ -3023,7 +3011,7 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
// the initial timestamp cannot be larger than max_initial_timestamp
|
|
|
|
// the initial timestamp cannot be larger than max_initial_timestamp
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
|
|
|
|
if (is_initial && params.max_initial_timestamp > 0.0f) {
|
|
|
|
if (is_initial && params.max_initial_timestamp > 0.0f) {
|
|
|
|
const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
|
|
|
|
const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
|
|
|
|
const int tid0 = std::round(params.max_initial_timestamp/precision);
|
|
|
|
const int tid0 = std::round(params.max_initial_timestamp/precision);
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
|
|
|
|
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
|
|
|
@ -3091,24 +3079,24 @@ static struct whisper_token_data whisper_sample_next_token(
|
|
|
|
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
|
|
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
//switch (params.strategy) {
|
|
|
|
case WHISPER_SAMPLING_GREEDY:
|
|
|
|
// case WHISPER_SAMPLING_GREEDY:
|
|
|
|
{
|
|
|
|
// {
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
|
|
|
|
// // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
|
|
|
|
// TODO: implement
|
|
|
|
// // TODO: implement
|
|
|
|
result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
|
|
|
|
// result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
|
|
|
|
} break;
|
|
|
|
// } break;
|
|
|
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
// case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
{
|
|
|
|
// {
|
|
|
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
|
|
|
|
// // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
|
|
|
|
// TODO: implement
|
|
|
|
// // TODO: implement
|
|
|
|
} break;
|
|
|
|
// } break;
|
|
|
|
}
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
sum_logprobs += logprobs[result.id];
|
|
|
|
//sum_logprobs += logprobs[result.id];
|
|
|
|
printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
|
|
|
|
//printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));
|
|
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
//return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_full(
|
|
|
|
int whisper_full(
|
|
|
@ -3191,6 +3179,7 @@ int whisper_full(
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < n_decoders; i++) {
|
|
|
|
for (int i = 1; i < n_decoders; i++) {
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
if (ctx->decoders[i].kv_self.ctx == nullptr) {
|
|
|
|
if (ctx->decoders[i].kv_self.ctx == nullptr) {
|
|
|
|
ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
|
|
|
|
ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
|
|
|
|
if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
|
|
|
|
if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
|
|
|
@ -3201,7 +3190,9 @@ int whisper_full(
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
|
|
|
|
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
|
|
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
|
|
|
|
ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity());
|
|
|
|
|
|
|
|
|
|
|
|
ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab);
|
|
|
|
|
|
|
|
ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -3242,14 +3233,12 @@ int whisper_full(
|
|
|
|
int progress_prev = 0;
|
|
|
|
int progress_prev = 0;
|
|
|
|
int progress_step = 5;
|
|
|
|
int progress_step = 5;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token_data> tokens_cur;
|
|
|
|
int seek = seek_start;
|
|
|
|
tokens_cur.reserve(whisper_n_text_ctx(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> prompt;
|
|
|
|
std::vector<whisper_token> prompt;
|
|
|
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
|
|
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
// main loop
|
|
|
|
// main loop
|
|
|
|
int seek = seek_start;
|
|
|
|
|
|
|
|
while (true) {
|
|
|
|
while (true) {
|
|
|
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
|
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
|
|
while (progress_cur >= progress_prev + progress_step) {
|
|
|
|
while (progress_cur >= progress_prev + progress_step) {
|
|
|
@ -3283,6 +3272,181 @@ int whisper_full(
|
|
|
|
prompt_past.clear();
|
|
|
|
prompt_past.clear();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//prompt.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
|
|
|
|
|
//if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
// int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// prompt = { whisper_token_prev(ctx) };
|
|
|
|
|
|
|
|
// prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// prompt_past.clear();
|
|
|
|
|
|
|
|
// prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// print the prompt
|
|
|
|
|
|
|
|
////printf("\n\n");
|
|
|
|
|
|
|
|
////for (int i = 0; i < prompt.size(); i++) {
|
|
|
|
|
|
|
|
//// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
|
|
|
|
|
|
|
|
////}
|
|
|
|
|
|
|
|
////printf("\n\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//int n_past = 0;
|
|
|
|
|
|
|
|
//int seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// the accumulated transcription in the current interation
|
|
|
|
|
|
|
|
//int result_len = 0;
|
|
|
|
|
|
|
|
//tokens_cur.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//bool failed = false; // has the current segment failed to decode?
|
|
|
|
|
|
|
|
//bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//double sum_logprobs = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
|
|
|
// if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
|
|
|
// fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
|
|
|
// return -7;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// n_past += prompt.size();
|
|
|
|
|
|
|
|
// prompt.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // sample the next token based on the selected decoding strategy + parameters
|
|
|
|
|
|
|
|
// // also, update the sliding window position based on the sampled timestamp tokens
|
|
|
|
|
|
|
|
// {
|
|
|
|
|
|
|
|
// const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // timestamp token - update sliding window
|
|
|
|
|
|
|
|
// if (token.id > whisper_token_beg(ctx)) {
|
|
|
|
|
|
|
|
// const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // do not allow to go back in time
|
|
|
|
|
|
|
|
// if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// seek_delta = seek_delta_new;
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// has_ts = true;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // add it to the context
|
|
|
|
|
|
|
|
// prompt.push_back(token.id);
|
|
|
|
|
|
|
|
// tokens_cur.push_back(token);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// //{
|
|
|
|
|
|
|
|
// // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
|
|
|
|
// // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
|
|
|
|
// //}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // end of segment
|
|
|
|
|
|
|
|
// if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
|
|
|
|
// (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
|
|
|
|
|
|
// (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
|
|
|
|
|
|
// ) {
|
|
|
|
|
|
|
|
// if (result_len == 0) {
|
|
|
|
|
|
|
|
// if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// } else {
|
|
|
|
|
|
|
|
// failed = true;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if (params.single_segment) {
|
|
|
|
|
|
|
|
// result_len = i + 1;
|
|
|
|
|
|
|
|
// seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // TESTS: if no tensors are loaded, it means we are running tests
|
|
|
|
|
|
|
|
// if (ctx->model.n_loaded == 0) {
|
|
|
|
|
|
|
|
// seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // sometimes, the decoding can get stuck in a repetition loop
|
|
|
|
|
|
|
|
// // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
|
|
|
|
|
|
// if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
|
|
|
|
|
|
// failed = true;
|
|
|
|
|
|
|
|
// break;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//if (failed) {
|
|
|
|
|
|
|
|
// // when we fail to sample timestamp token, retry by clearing the past prompt
|
|
|
|
|
|
|
|
// // if it fails again, then we advance the window by 1 second
|
|
|
|
|
|
|
|
// if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
// prompt_past.clear();
|
|
|
|
|
|
|
|
// } else {
|
|
|
|
|
|
|
|
// fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
|
|
|
|
|
|
|
|
// seek += 100;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// continue;
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//// shrink down to result_len
|
|
|
|
|
|
|
|
//tokens_cur.resize(result_len);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//for (const auto & r : tokens_cur) {
|
|
|
|
|
|
|
|
// prompt_past.push_back(r.id);
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int it = 0; it < (int) temperatures.size(); ++it) {
|
|
|
|
|
|
|
|
const float t_cur = temperatures[it];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int n_decoders_cur = 1;
|
|
|
|
|
|
|
|
switch (params.strategy) {
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
if (t_cur > 0.0f) {
|
|
|
|
|
|
|
|
n_decoders_cur = params.greedy.best_of;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
if (t_cur > 0.0f) {
|
|
|
|
|
|
|
|
n_decoders_cur = params.greedy.best_of;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
n_decoders_cur = params.beam_search.beam_size;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (t_cur > 0.5) {
|
|
|
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: clearing prompt_past\n", __func__);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoder.sequence.tokens.clear();
|
|
|
|
|
|
|
|
decoder.sequence.result_len = 0;
|
|
|
|
|
|
|
|
decoder.sequence.sum_logprobs = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoder.n_past = 0;
|
|
|
|
|
|
|
|
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoder.failed = false;
|
|
|
|
|
|
|
|
decoder.completed = false;
|
|
|
|
|
|
|
|
decoder.has_ts = false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// init prompt and kv cache for the current iteration
|
|
|
|
|
|
|
|
// run whisper_decoder() only for decoder 0 and copy the results for the other decoders
|
|
|
|
|
|
|
|
{
|
|
|
|
prompt.clear();
|
|
|
|
prompt.clear();
|
|
|
|
|
|
|
|
|
|
|
|
// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
|
// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
@ -3291,11 +3455,9 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
prompt = { whisper_token_prev(ctx) };
|
|
|
|
prompt = { whisper_token_prev(ctx) };
|
|
|
|
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
|
|
|
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
|
|
|
|
|
|
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// init new transcription with sot, language (opt) and task tokens
|
|
|
|
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
|
|
|
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
|
|
|
|
|
|
|
|
|
|
|
// print the prompt
|
|
|
|
// print the prompt
|
|
|
@ -3305,110 +3467,101 @@ int whisper_full(
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
//printf("\n\n");
|
|
|
|
//printf("\n\n");
|
|
|
|
|
|
|
|
|
|
|
|
int n_past = 0;
|
|
|
|
if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) {
|
|
|
|
int seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
|
|
|
return -7;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// the accumulated transcription in the current interation
|
|
|
|
whisper_process_logits(*ctx, ctx->decoders[0], params);
|
|
|
|
int result_len = 0;
|
|
|
|
|
|
|
|
tokens_cur.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool failed = false; // has the current segment failed to decode?
|
|
|
|
for (int i = 1; i < n_decoders_cur; ++i) {
|
|
|
|
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
|
|
|
|
|
|
|
|
double sum_logprobs = 0.0;
|
|
|
|
memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size());
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
decoder.n_past += prompt.size();
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
|
|
return -7;
|
|
|
|
memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
n_past += prompt.size();
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
prompt.clear();
|
|
|
|
bool completed = true;
|
|
|
|
|
|
|
|
|
|
|
|
// sample the next token based on the selected decoding strategy + parameters
|
|
|
|
switch (params.strategy) {
|
|
|
|
// also, update the sliding window position based on the sampled timestamp tokens
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
|
|
{
|
|
|
|
{
|
|
|
|
const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.completed || decoder.failed) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// timestamp token - update sliding window
|
|
|
|
if (t_cur < 1e-6f) {
|
|
|
|
if (token.id > whisper_token_beg(ctx)) {
|
|
|
|
// select top token
|
|
|
|
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
} else {
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
// TODO: ..
|
|
|
|
|
|
|
|
} break;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// do not allow to go back in time
|
|
|
|
if (completed) {
|
|
|
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
seek_delta = seek_delta_new;
|
|
|
|
for (int i = 0; i < n_decoders_cur; ++i) {
|
|
|
|
result_len = i + 1;
|
|
|
|
auto & decoder = ctx->decoders[i];
|
|
|
|
has_ts = true;
|
|
|
|
|
|
|
|
|
|
|
|
if (decoder.failed || decoder.completed) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// add it to the context
|
|
|
|
decoder.tokens.resize(1);
|
|
|
|
prompt.push_back(token.id);
|
|
|
|
decoder.tokens[0] = decoder.sequence.tokens.back().id;
|
|
|
|
tokens_cur.push_back(token);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//{
|
|
|
|
if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) {
|
|
|
|
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
// printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
return -7;
|
|
|
|
//}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// end of segment
|
|
|
|
whisper_process_logits(*ctx, decoder, params);
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
|
|
|
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
|
|
++decoder.n_past;
|
|
|
|
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
|
|
|
|
|
|
) {
|
|
|
|
|
|
|
|
if (result_len == 0) {
|
|
|
|
|
|
|
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
failed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (params.single_segment) {
|
|
|
|
// TODO: rank the resulting sequences and select the best one
|
|
|
|
result_len = i + 1;
|
|
|
|
{
|
|
|
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
bool success = true;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TESTS: if no tensors are loaded, it means we are running tests
|
|
|
|
// TODO: implement logprob threshold + compression threshold
|
|
|
|
if (ctx->model.n_loaded == 0) {
|
|
|
|
{
|
|
|
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// sometimes, the decoding can get stuck in a repetition loop
|
|
|
|
if (success) {
|
|
|
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
|
|
|
|
|
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
|
|
|
|
|
|
failed = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (failed) {
|
|
|
|
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
|
|
// when we fail to sample timestamp token, retry by clearing the past prompt
|
|
|
|
|
|
|
|
// if it fails again, then we advance the window by 1 second
|
|
|
|
|
|
|
|
if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
|
|
|
|
|
|
|
|
seek += 100;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// shrink down to result_len
|
|
|
|
// TODO
|
|
|
|
tokens_cur.resize(result_len);
|
|
|
|
const int seek_delta = 0;
|
|
|
|
|
|
|
|
whisper_sequence seq_best;
|
|
|
|
|
|
|
|
|
|
|
|
for (const auto & r : tokens_cur) {
|
|
|
|
const auto & tokens_cur = seq_best.tokens;
|
|
|
|
prompt_past.push_back(r.id);
|
|
|
|
|
|
|
|
}
|
|
|
|
// TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens
|
|
|
|
|
|
|
|
//prompt_past.clear();
|
|
|
|
|
|
|
|
//prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
|
|
|
|
|
|
|
|
|
|
|
// store the text from this iteration
|
|
|
|
// store the text from this iteration
|
|
|
|
if (!tokens_cur.empty()) {
|
|
|
|
if (!tokens_cur.empty()) {
|
|
|
@ -3450,10 +3603,10 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
n_new = whisper_wrap_segment(ctx, params.max_len);
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
if (params.new_segment_callback) {
|
|
|
@ -3494,10 +3647,10 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
n_new = whisper_wrap_segment(ctx, params.max_len);
|
|
|
|
n_new = whisper_wrap_segment(*ctx, params.max_len);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
if (params.new_segment_callback) {
|
|
|
@ -3535,9 +3688,6 @@ int whisper_full_parallel(
|
|
|
|
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
if (!kv_cache_reinit(ctx_p.kv_cross)) {
|
|
|
|
if (!kv_cache_reinit(ctx_p.kv_cross)) {
|
|
|
@ -3545,6 +3695,7 @@ int whisper_full_parallel(
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TAGS: WHISPER_DECODER_INIT
|
|
|
|
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
|
|
|
|
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
|
|
|
|
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
|
|
|
|
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
|
|
|
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
|
|
|
@ -3552,7 +3703,9 @@ int whisper_full_parallel(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
|
|
|
|
|
|
|
|
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -3747,14 +3900,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
struct whisper_context & ctx,
|
|
|
|
int i_segment,
|
|
|
|
int i_segment,
|
|
|
|
float thold_pt,
|
|
|
|
float thold_pt,
|
|
|
|
float thold_ptsum) {
|
|
|
|
float thold_ptsum) {
|
|
|
|
auto & segment = ctx->result_all[i_segment];
|
|
|
|
auto & segment = ctx.result_all[i_segment];
|
|
|
|
auto & tokens = segment.tokens;
|
|
|
|
auto & tokens = segment.tokens;
|
|
|
|
|
|
|
|
|
|
|
|
const int n_samples = ctx->energy.size();
|
|
|
|
const int n_samples = ctx.energy.size();
|
|
|
|
|
|
|
|
|
|
|
|
if (n_samples == 0) {
|
|
|
|
if (n_samples == 0) {
|
|
|
|
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
|
|
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
|
@ -3777,28 +3930,28 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto & t_beg = ctx->t_beg;
|
|
|
|
auto & t_beg = ctx.t_beg;
|
|
|
|
auto & t_last = ctx->t_last;
|
|
|
|
auto & t_last = ctx.t_last;
|
|
|
|
auto & tid_last = ctx->tid_last;
|
|
|
|
auto & tid_last = ctx.tid_last;
|
|
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n; ++j) {
|
|
|
|
for (int j = 0; j < n; ++j) {
|
|
|
|
auto & token = tokens[j];
|
|
|
|
auto & token = tokens[j];
|
|
|
|
|
|
|
|
|
|
|
|
if (j == 0) {
|
|
|
|
if (j == 0) {
|
|
|
|
if (token.id == whisper_token_beg(ctx)) {
|
|
|
|
if (token.id == whisper_token_beg(&ctx)) {
|
|
|
|
tokens[j ].t0 = t0;
|
|
|
|
tokens[j ].t0 = t0;
|
|
|
|
tokens[j ].t1 = t0;
|
|
|
|
tokens[j ].t1 = t0;
|
|
|
|
tokens[j + 1].t0 = t0;
|
|
|
|
tokens[j + 1].t0 = t0;
|
|
|
|
|
|
|
|
|
|
|
|
t_beg = t0;
|
|
|
|
t_beg = t0;
|
|
|
|
t_last = t0;
|
|
|
|
t_last = t0;
|
|
|
|
tid_last = whisper_token_beg(ctx);
|
|
|
|
tid_last = whisper_token_beg(&ctx);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
tokens[j ].t0 = t_last;
|
|
|
|
tokens[j ].t0 = t_last;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
|
|
|
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
|
|
|
|
|
|
|
|
|
|
|
|
tokens[j].id = token.id;
|
|
|
|
tokens[j].id = token.id;
|
|
|
|
tokens[j].tid = token.tid;
|
|
|
|
tokens[j].tid = token.tid;
|
|
|
@ -3806,7 +3959,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
tokens[j].pt = token.pt;
|
|
|
|
tokens[j].pt = token.pt;
|
|
|
|
tokens[j].ptsum = token.ptsum;
|
|
|
|
tokens[j].ptsum = token.ptsum;
|
|
|
|
|
|
|
|
|
|
|
|
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
|
|
|
|
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
|
|
|
|
|
|
|
|
|
|
|
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
|
|
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
|
|
|
if (j > 0) {
|
|
|
|
if (j > 0) {
|
|
|
@ -3885,7 +4038,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
const int hw = WHISPER_SAMPLE_RATE/8;
|
|
|
|
const int hw = WHISPER_SAMPLE_RATE/8;
|
|
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n; j++) {
|
|
|
|
for (int j = 0; j < n; j++) {
|
|
|
|
if (tokens[j].id >= whisper_token_eot(ctx)) {
|
|
|
|
if (tokens[j].id >= whisper_token_eot(&ctx)) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -3900,15 +4053,15 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
float sum = 0.0f;
|
|
|
|
float sum = 0.0f;
|
|
|
|
|
|
|
|
|
|
|
|
for (int k = ss0; k < ss1; k++) {
|
|
|
|
for (int k = ss0; k < ss1; k++) {
|
|
|
|
sum += ctx->energy[k];
|
|
|
|
sum += ctx.energy[k];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const float thold = 0.5*sum/ns;
|
|
|
|
const float thold = 0.5*sum/ns;
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
int k = s0;
|
|
|
|
int k = s0;
|
|
|
|
if (ctx->energy[k] > thold && j > 0) {
|
|
|
|
if (ctx.energy[k] > thold && j > 0) {
|
|
|
|
while (k > 0 && ctx->energy[k] > thold) {
|
|
|
|
while (k > 0 && ctx.energy[k] > thold) {
|
|
|
|
k--;
|
|
|
|
k--;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
tokens[j].t0 = sample_to_timestamp(k);
|
|
|
|
tokens[j].t0 = sample_to_timestamp(k);
|
|
|
@ -3918,7 +4071,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
s0 = k;
|
|
|
|
s0 = k;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
while (ctx->energy[k] < thold && k < s1) {
|
|
|
|
while (ctx.energy[k] < thold && k < s1) {
|
|
|
|
k++;
|
|
|
|
k++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
s0 = k;
|
|
|
|
s0 = k;
|
|
|
@ -3928,8 +4081,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
int k = s1;
|
|
|
|
int k = s1;
|
|
|
|
if (ctx->energy[k] > thold) {
|
|
|
|
if (ctx.energy[k] > thold) {
|
|
|
|
while (k < n_samples - 1 && ctx->energy[k] > thold) {
|
|
|
|
while (k < n_samples - 1 && ctx.energy[k] > thold) {
|
|
|
|
k++;
|
|
|
|
k++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
|
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
|
@ -3939,7 +4092,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
s1 = k;
|
|
|
|
s1 = k;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
while (ctx->energy[k] < thold && k > s0) {
|
|
|
|
while (ctx.energy[k] < thold && k > s0) {
|
|
|
|
k--;
|
|
|
|
k--;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
s1 = k;
|
|
|
|
s1 = k;
|
|
|
|