whisper : add comments

pull/291/head
Georgi Gerganov 3 years ago
parent 6700cd57f7
commit d83e47573b

@@ -633,24 +633,24 @@ int main(int argc, char ** argv) {
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_threshold = params.entropy_thold;
wparams.logprob_threshold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
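For context (not part of the diff): a minimal sketch of how these wparams feed into a full transcription call, assuming the C API of this era (whisper_init, whisper_full, whisper_free) and a pcmf32 buffer of 16 kHz mono float samples prepared by the caller:

#include "whisper.h"
#include <cstdio>
#include <vector>

int transcribe(const char * model_path, const std::vector<float> & pcmf32) {
    struct whisper_context * ctx = whisper_init(model_path);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // start from the defaults, then override - same pattern as main.cpp above
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.print_progress = false;
    wparams.language       = "en";

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        fprintf(stderr, "failed to process audio\n");
        whisper_free(ctx);
        return 2;
    }

    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        printf("%s", whisper_full_get_segment_text(ctx, i));
    }

    whisper_free(ctx);
    return 0;
}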

@@ -616,8 +616,8 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up;
// disable best_of fallback
wparams.temperature_increment = -1.0f;
wparams.greedy.best_of = -1;
wparams.temperature_inc = -1.0f;
wparams.greedy.best_of = -1;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
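Side note on the stream example change above: a negative temperature_inc collapses the temperature schedule to the single initial temperature, and greedy.best_of = -1 skips sampling multiple candidates, so no fallback re-decoding can occur - presumably this is what keeps the latency of the real-time stream example bounded.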

@@ -25,7 +25,9 @@
} \
} while (0)
// define this to enable verbose trace logging - useful for debugging purposes
//#define WHISPER_DEBUG
#if defined(WHISPER_DEBUG)
#define WHISPER_PRINT_DEBUG(...) \
do { \
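To get the WHISPER_PRINT_DEBUG trace output, either uncomment the #define above or define WHISPER_DEBUG at build time (e.g. add -DWHISPER_DEBUG to the compiler flags); when the macro is not defined, the calls presumably compile away to nothing.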
@@ -380,6 +382,8 @@ struct whisper_kv_cache {
struct ggml_context * ctx;
std::vector<uint8_t> buf;
int n; // number of tokens currently in the cache
};
struct whisper_model {
@@ -442,12 +446,13 @@ struct whisper_sequence {
// TAGS: WHISPER_DECODER_INIT
struct whisper_decoder {
// each decoder keeps its own KV-cache
whisper_kv_cache kv_self;
// the currently generated sequence of tokens
whisper_sequence sequence;
int n_past;
int seek_delta;
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode?
bool completed; // has the decoder completed the current segment?
@@ -476,6 +481,8 @@ struct whisper_context {
whisper_model model;
whisper_vocab vocab;
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
@@ -490,6 +497,7 @@ struct whisper_context {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// work container used to avoid memory allocations
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
mutable std::mt19937 rng; // used for sampling at t > 0.0
@@ -680,6 +688,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
}
// initialize all memory buffers
// always have at least one decoder
wctx.model.buf = new std::vector<uint8_t>();
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
@@ -2671,127 +2682,77 @@ const char * whisper_print_system_info(void) {
////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result;
struct whisper_full_params result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
switch (strategy) {
case WHISPER_SAMPLING_GREEDY:
{
result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.language =*/ "en",
/*.suppress_blank =*/ true,
/*.suppress_blank =*/ true,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_inc =*/ 0.2f,
/*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f,
/*.temperature_increment =*/ 0.2f,
/*.entropy_threshold =*/ 2.4f,
/*.logprob_threshold =*/ -1.0f,
/*.no_speech_threshold =*/ 0.6f,
/*.greedy =*/ {
/*.best_of =*/ -1,
},
/*.greedy =*/ {
/*.best_of =*/ 5,
},
/*.beam_search =*/ {
/*.beam_size =*/ -1,
/*.beam_search =*/ {
/*.beam_size =*/ -1,
/*.patience =*/ -1.0f,
},
/*.patience =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
};
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
switch (strategy) {
case WHISPER_SAMPLING_GREEDY:
{
result.greedy = {
/*.best_of =*/ 1,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result = {
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.suppress_blank =*/ true,
/*.temperature =*/ 0.0f,
/*.max_initial_timestamp =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_increment =*/ 0.2f,
/*.entropy_threshold =*/ 2.4f,
/*.logprob_threshold =*/ -1.0f,
/*.no_speech_threshold =*/ 0.6f,
result.beam_search = {
/*.beam_size =*/ 5,
/*.greedy =*/ {
/*.best_of =*/ 5,
},
/*.beam_search =*/ {
/*.beam_size =*/ 5,
/*.patience =*/ -1.0f,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
/*.patience =*/ -1.0f,
};
} break;
}
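A quick sketch (not part of the commit) of the refactored initializer from the caller's side: the two presets now share every common default, and the switch only overrides the strategy-specific sub-struct:

whisper_full_params wp_greedy = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
// wp_greedy.greedy.best_of == 1

whisper_full_params wp_beam = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
// wp_beam.beam_search.beam_size == 5, wp_beam.beam_search.patience == -1.0f
wp_beam.beam_search.beam_size = 8; // any field can still be overridden afterwards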
@@ -2862,8 +2823,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
}
// process the logits for the selected decoder
// - applyies logit filters
// - computes logprobs
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
const struct whisper_full_params params,
@@ -2938,11 +2899,11 @@ static void whisper_process_logits(
}
}
// the initial timestamp cannot be larger than max_initial_timestamp
// the initial timestamp cannot be larger than max_initial_ts
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial && params.max_initial_timestamp > 0.0f) {
if (is_initial && params.max_initial_ts > 0.0f) {
const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
const int tid0 = std::round(params.max_initial_timestamp/precision);
const int tid0 = std::round(params.max_initial_ts/precision);
for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
logits[i] = -INFINITY;
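Worked numbers for the clamp above, assuming the usual WHISPER_CHUNK_SIZE of 30 seconds and n_audio_ctx of 1500: precision = 30/1500 = 0.02 s per timestamp token, so with the default max_initial_ts = 1.0f we get tid0 = round(1.0/0.02) = 50, and every timestamp token from token_beg + 51 onward is masked with -INFINITY, i.e. the first timestamp of a segment cannot exceed 1.0 s.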
@@ -3073,7 +3034,7 @@ static whisper_token_data whisper_sample_token(
}
}
result.pt = max_ts/(sum_ts + 1e-10);
result.pt = max_ts/(sum_ts + 1e-10);
result.ptsum = sum_ts;
}
@@ -3127,7 +3088,13 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
return a.first > b.first;
});
std::vector<whisper_token_data> result(k);
std::vector<whisper_token_data> result;
result.reserve(k);
whisper_token tid = vocab.token_beg; // initialize so the values are defined even if no timestamp token is seen
float pt = 0.0f;
float ptsum = 0.0f;
{
double sum_ts = 0.0;
@@ -3141,24 +3108,18 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
sum_ts += probs[i];
if (max_ts < probs[i]) {
max_ts = probs[i];
result[0].tid = i;
tid = i;
}
}
result[0].pt = max_ts/(sum_ts + 1e-10);
result[0].ptsum = sum_ts;
pt = max_ts/(sum_ts + 1e-10);
ptsum = sum_ts;
}
for (int i = 0; i < k; ++i) {
result[i].id = logits_id[i].second;
result[i].p = probs[result[i].id];
result[i].plog = logprobs[result[i].id];
result[i].tid = result[0].tid;
result[i].pt = result[0].pt;
result[i].ptsum = result[0].ptsum;
result[i].t0 = -1;
result[i].t1 = -1;
result[i].vlen = 0.0f;
const auto id = logits_id[i].second;
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
if (result[i].id >= vocab.token_beg) {
result[i].tid = result[i].id;
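For reference, the brace-init in the push_back above relies on the field order of whisper_token_data as declared in whisper.h at this point (shown here as an aid, not part of the diff):

typedef struct whisper_token_data {
    whisper_token id;  // token id
    whisper_token tid; // forced timestamp token id
    float p;           // probability of the token
    float plog;        // log probability of the token
    float pt;          // probability of the timestamp token
    float ptsum;       // sum of probabilities of all timestamp tokens
    int64_t t0;        // start time of the token
    int64_t t1;        // end time of the token
    float vlen;        // voice length of the token
} whisper_token_data;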
@@ -3276,8 +3237,8 @@ int whisper_full(
// a set of temperatures to use
// [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
std::vector<float> temperatures;
if (params.temperature_increment > 0.0f) {
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) {
if (params.temperature_inc > 0.0f) {
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
temperatures.push_back(t);
}
} else {
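With the defaults above (temperature = 0.0f, temperature_inc = 0.2f) this loop produces the schedule { 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 }: decoding starts deterministic and retries at increasing temperatures until a segment passes the entropy/logprob checks. A non-positive temperature_inc (as in the stream example) yields only the single initial temperature.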
@@ -3448,6 +3409,8 @@ int whisper_full(
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
decoder.kv_self.n = 0;
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs_all = 0.0;
@@ -3456,7 +3419,6 @@ int whisper_full(
decoder.sequence.entropy = 0.0;
decoder.sequence.score = -INFINITY;
decoder.n_past = 0;
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
decoder.failed = false;
@@ -3497,7 +3459,7 @@ int whisper_full(
whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
ctx->decoders[0].n_past += prompt.size();
ctx->decoders[0].kv_self.n += prompt.size();
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3505,7 +3467,7 @@ int whisper_full(
memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
decoder.n_past += prompt.size();
decoder.kv_self.n += prompt.size();
memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
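Note the optimization in this hunk: the initial prompt is evaluated once by decoder 0, and its self-attention KV cache plus the resulting probs/logits are simply memcpy'd into the remaining decoders, since every decoder starts the segment from exactly the same prompt state.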
@@ -3519,6 +3481,7 @@ int whisper_full(
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
const int64_t t_start_sample_us = ggml_time_us();
// store the KV caches of all decoders when doing beam-search
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
kv_bufs.resize(n_decoders_cur);
for (int j = 0; j < n_decoders_cur; ++j) {
@@ -3538,6 +3501,7 @@ int whisper_full(
beam_candidates.clear();
}
// generate new sequence candidates for each decoder
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3571,6 +3535,7 @@ int whisper_full(
};
}
// for beam-search, choose the top candidates and update the KV caches
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
std::sort(
beam_candidates.begin(),
@@ -3606,6 +3571,10 @@ int whisper_full(
}
}
// update the decoder state
// - check if the sequence is completed
// - check if the sequence has failed
// - update sliding window based on timestamp tokens
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3705,6 +3674,7 @@ int whisper_full(
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
// obtain logits for the next token
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3715,11 +3685,11 @@ int whisper_full(
decoder.tokens_tmp.resize(1);
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
return -8;
}
{
@@ -3727,7 +3697,7 @@ int whisper_full(
whisper_process_logits(*ctx, params, decoder, t_cur);
++decoder.n_past;
++decoder.kv_self.n;
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
@@ -3736,7 +3706,7 @@ int whisper_full(
// rank the resulting sequences and select the best one
{
double best_score = -1e9;
double best_score = -INFINITY;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -3751,9 +3721,9 @@ int whisper_full(
WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
__func__, j, decoder.sequence.entropy, params.entropy_threshold);
__func__, j, decoder.sequence.entropy, params.entropy_thold);
decoder.failed = true;
@@ -3765,31 +3735,33 @@ int whisper_full(
best_decoder_id = j;
}
}
}
bool success = true;
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
}
// implement logprob threshold
// was the decoding successful for the current temperature?
{
auto & decoder = ctx->decoders[best_decoder_id];
bool success = true;
const auto & decoder = ctx->decoders[best_decoder_id];
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) {
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
success = false;
}
}
if (success) {
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
break;
break;
}
}
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
}
// output results through a user-provided callback
{
const auto & best_decoder = ctx->decoders[best_decoder_id];

@@ -137,6 +137,7 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
@@ -218,8 +219,8 @@ extern "C" {
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};
// Text segment callback
@@ -239,17 +240,17 @@ extern "C" {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
@@ -259,10 +260,11 @@ extern "C" {
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
@@ -271,31 +273,34 @@ extern "C" {
const char * language;
// common decoding parameters:
bool suppress_blank;
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
float temperature;
float max_initial_timestamp;
float length_penalty;
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
float temperature_increment;
float entropy_threshold; // analogous to OpenAI's compression_ratio_threshold
float logprob_threshold;
float no_speech_threshold; // TODO: not implemented
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int best_of;
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy;
struct {
int beam_size;
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};
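A short sketch (not part of the header) of configuring the renamed fallback fields:

whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.temperature     = 0.0f;  // start deterministic
wparams.temperature_inc = 0.2f;  // on failure, retry at 0.2, 0.4, ..., up to 1.0
wparams.entropy_thold   = 2.4f;  // a segment whose token entropy is below this fails (repetition)
wparams.logprob_thold   = -1.0f; // a segment whose average logprob is below this fails
// setting wparams.temperature_inc to a negative value disables the fallback entirely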
