sync : latest whisper.cpp

pull/35/head
Georgi Gerganov 1 year ago
parent efa2cc36a2
commit 8f8a5aca99
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -91,12 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; std::string arg = argv[i];
if (arg == "-"){ if (arg == "-"){
params.fname_inp.push_back(arg); params.fname_inp.push_back(arg);
continue; continue;
} }
if (arg[0] != '-') { if (arg[0] != '-') {
params.fname_inp.push_back(arg); params.fname_inp.push_back(arg);
continue; continue;

@ -592,16 +592,16 @@ struct whisper_context {
mutable std::mt19937 rng; // used for sampling at t > 0.0 mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id; int lang_id = 0; // english by default
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg = 0;
int64_t t_last; int64_t t_last = 0;
whisper_token tid_last; whisper_token tid_last;
std::vector<float> energy; // PCM signal energy std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default int32_t exp_n_audio_ctx = 0; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) { void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH) #if defined(WHISPER_USE_SCRATCH)
@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) + MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder // this is the memory required by one decoder
const size_t mem_required_decoder = const size_t mem_required_decoder =
@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en", /*.language =*/ "en",
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/true, /*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f, /*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
}; };
switch (strategy) { switch (strategy) {
@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
return res; return res;
} }
static const std::vector<std::string> non_speech_tokens static const std::vector<std::string> non_speech_tokens = {
{
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--", "_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@ -3090,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens
// - applies logit filters // - applies logit filters
// - computes logprobs and probs // - computes logprobs and probs
static void whisper_process_logits( static void whisper_process_logits(
const struct whisper_context & ctx, struct whisper_context & ctx,
const struct whisper_full_params params, const struct whisper_full_params params,
struct whisper_decoder & decoder, struct whisper_decoder & decoder,
float temperature) { float temperature) {
@ -3146,29 +3148,27 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY; logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY;
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
// suppress non-speech tokens // suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) if (params.suppress_non_speech_tokens) {
{ for (const std::string & token : non_speech_tokens) {
for (const std::string &token : non_speech_tokens) const std::string suppress_tokens[] = {token, " " + token};
{ for (const std::string & suppress_token : suppress_tokens) {
std::string suppress_tokens[] = {token, " " + token}; if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
for (const std::string &suppress_token : suppress_tokens)
{
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
{
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
} }
} }
} }
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
{
logits[vocab.token_to_id.at(" -")] = -INFINITY; logits[vocab.token_to_id.at(" -")] = -INFINITY;
} }
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
{
logits[vocab.token_to_id.at(" '")] = -INFINITY; logits[vocab.token_to_id.at(" '")] = -INFINITY;
} }
} }
@ -3854,7 +3854,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
}); });
unsigned int cur_c = 0; uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -4339,7 +4339,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
} }
int whisper_full_lang_id(struct whisper_context * ctx) { int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->lang_id; return ctx->lang_id;
} }
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {

@ -243,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted // If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
//
// Arguments, as supplied at the call site in whisper_process_logits():
//   ctx       - the whisper context that whisper_process_logits() is operating on
//   tokens    - pointer to the first element of the caller's tokens_cur sequence (read-only)
//   n_tokens  - number of elements in tokens (tokens_cur.size())
//   logits    - writable logits buffer (logits.data()); modify it in place
//               NOTE(review): the buffer length is not passed to the callback -
//               presumably the vocabulary size; confirm before indexing
//   user_data - the logits_filter_callback_user_data value from whisper_full_params
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function // Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp: // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params() // whisper_full_default_params()
@ -315,6 +325,10 @@ extern "C" {
// called each time before the encoder starts // called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback; whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data; void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
}; };
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);

Loading…
Cancel
Save