diff --git a/whisper.cpp b/whisper.cpp index 0776076..d23e97f 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -204,6 +204,10 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; + // used to avoid memory allocations during sampling + // TODO: move to whisper_context in the future + std::vector> probs_id; + id token_eot = 50256; id token_sot = 50257; id token_prev = 50360; @@ -551,6 +555,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx std::string word; std::vector tmp; + + tmp.reserve(128); + for (int i = 0; i < n_vocab; i++) { uint32_t len; read_safe(fin, len); @@ -603,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx vocab.id_to_token[i] = word; } } + + wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + + vocab.probs_id.reserve(n_vocab); } { @@ -1021,7 +1033,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx std::string name; std::vector tmp(length); // create a buffer - fin.read( &tmp[0], tmp.size() ); // read to buffer + fin.read(&tmp[0], tmp.size()); // read to buffer name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name) == model.tensors.end()) { @@ -1849,7 +1861,7 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( - const whisper_vocab & vocab, + whisper_vocab & vocab, const float * probs, bool force_timestamp, bool is_initial) { @@ -1857,11 +1869,11 @@ static whisper_token_data whisper_sample_best( 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; - int n_logits = vocab.id_to_token.size(); + const int n_logits = vocab.n_vocab; - std::vector> probs_id; - probs_id.reserve(n_logits); + auto & probs_id = vocab.probs_id; + probs_id.clear(); for (int i = 0; i < n_logits; i++) { probs_id.emplace_back(probs[i], i); } @@ -2001,6 +2013,9 @@ static void fft(const std::vector & in, std::vector & out) { std::vector even; std::vector odd; + even.reserve(N/2); + odd.reserve(N/2); + for (int i = 0; i < N; i++) { if (i % 2 == 0) { even.push_back(in[i]); @@ -2434,7 +2449,7 @@ int whisper_lang_auto_detect( std::vector> probs_id; for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - probs_id.emplace_back( ctx->probs[token_lang], kv.second.first ); + probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); } // sort descending