@@ -592,16 +592,16 @@ struct whisper_context {

     mutable std::mt19937 rng; // used for sampling at t > 0.0

-    int lang_id;
+    int lang_id = 0; // english by default

     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg;
-    int64_t t_last;
+    int64_t t_beg = 0;
+    int64_t t_last = 0;
     whisper_token tid_last;
     std::vector<float> energy; // PCM signal energy

     // [EXPERIMENTAL] speed-up techniques
-    int32_t exp_n_audio_ctx; // 0 - use default
+    int32_t exp_n_audio_ctx = 0; // 0 - use default

     void use_buf(struct ggml_context * ctx, int i) {
 #if defined(WHISPER_USE_SCRATCH)
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                   MEM_REQ_SCRATCH3.at (model.type) +
             scale*MEM_REQ_MODEL.at (model.type) +
             scale*MEM_REQ_KV_CROSS.at(model.type) +
-            scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
+            scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));

         // this is the memory required by one decoder
         const size_t mem_required_decoder =
@@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language =*/ "en",

         /*.suppress_blank =*/ true,
-        /*.suppress_non_speech_tokens =*/true,
+        /*.suppress_non_speech_tokens =*/ false,

         /*.temperature =*/ 0.0f,
         /*.max_initial_ts =*/ 1.0f,
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

         /*.encoder_begin_callback =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
+
+        /*.logits_filter_callback =*/ nullptr,
+        /*.logits_filter_callback_user_data =*/ nullptr,
     };

     switch (strategy) {
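With the default for `suppress_non_speech_tokens` flipped to `false`, callers that want the filtering have to opt in explicitly. A minimal sketch of the caller side, assuming a loaded `ctx` and a `pcmf32` sample buffer as in `examples/main`; the snippet is illustrative and not part of this diff:

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // opt back in to non-speech suppression, since the new default is false
    wparams.suppress_non_speech_tokens = true;

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        // handle the error as examples/main does
    }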
@@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     return res;
 }

-static const std::vector<std::string> non_speech_tokens
-{
+static const std::vector<std::string> non_speech_tokens = {
     "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
     "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
     "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@@ -3090,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens
 // - applies logit filters
 // - computes logprobs and probs
 static void whisper_process_logits(
-        const struct whisper_context & ctx,
+              struct whisper_context & ctx,
         const struct whisper_full_params params,
               struct whisper_decoder & decoder,
               float temperature) {
@@ -3146,29 +3148,27 @@ static void whisper_process_logits(
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;

+        if (params.logits_filter_callback) {
+            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+        }
+
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-        if (params.suppress_non_speech_tokens)
-        {
-            for (const std::string &token : non_speech_tokens)
-            {
-                std::string suppress_tokens[] = {token, " " + token};
-                for (const std::string &suppress_token : suppress_tokens)
-                {
-                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
-                    {
+        if (params.suppress_non_speech_tokens) {
+            for (const std::string & token : non_speech_tokens) {
+                const std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string & suppress_token : suppress_tokens) {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
                         logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
                     }
                 }
             }

             // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" -")] = -INFINITY;
             }
-            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" '")] = -INFINITY;
             }
         }
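The new `logits_filter_callback` exposes the same kind of hook to applications. A sketch of a user-side filter, assuming the `whisper.h` typedef mirrors the call site above; `my_logits_filter` and `banned_id` are illustrative names, and `wparams` is the parameter struct from the earlier sketch:

    #include <math.h>    // INFINITY
    #include "whisper.h"

    // suppress one extra vocabulary id supplied through user_data,
    // using the same -INFINITY convention as the built-in filters
    static void my_logits_filter(
            struct whisper_context * ctx,
            const whisper_token_data * tokens,
            int n_tokens,
            float * logits,
            void * user_data) {
        (void) ctx; (void) tokens; (void) n_tokens;

        const whisper_token banned = *(const whisper_token *) user_data;
        logits[banned] = -INFINITY;
    }

    // wired up through the two params added in this diff
    whisper_token banned_id = 0; // whichever id the application wants to forbid; must outlive whisper_full()
    wparams.logits_filter_callback           = my_logits_filter;
    wparams.logits_filter_callback_user_data = &banned_id;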
@@ -3854,7 +3854,7 @@ int whisper_full(
                 return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
             });

-            unsigned int cur_c = 0;
+            uint32_t cur_c = 0;

             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = ctx->decoders[j];