whisper : move probs_id buffer to whisper_context

pull/291/head
Georgi Gerganov 3 years ago
parent 523e0494a6
commit 1163f266d2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -220,10 +220,6 @@ struct whisper_vocab {
std::map<token, id> token_to_id; std::map<token, id> token_to_id;
std::map<id, token> id_to_token; std::map<id, token> id_to_token;
// used to avoid memory allocations during sampling
// TODO: move to whisper_context in the future
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
id token_eot = 50256; id token_eot = 50256;
id token_sot = 50257; id token_sot = 50257;
id token_prev = 50360; id token_prev = 50360;
@ -387,14 +383,14 @@ struct whisper_model {
struct ggml_tensor * e_ln_b; struct ggml_tensor * e_ln_b;
// decoder.positional_embedding // decoder.positional_embedding
struct ggml_tensor * d_pe; // DD struct ggml_tensor * d_pe;
// decoder.token_embedding // decoder.token_embedding
struct ggml_tensor * d_te; // DD struct ggml_tensor * d_te;
// decoder.ln // decoder.ln
struct ggml_tensor * d_ln_w; // DD struct ggml_tensor * d_ln_w;
struct ggml_tensor * d_ln_b; // DD struct ggml_tensor * d_ln_b;
std::vector<whisper_layer_encoder> layers_encoder; std::vector<whisper_layer_encoder> layers_encoder;
std::vector<whisper_layer_decoder> layers_decoder; std::vector<whisper_layer_decoder> layers_decoder;
@ -451,6 +447,8 @@ struct whisper_context {
std::vector<float> work_logits; std::vector<float> work_logits;
std::vector<float> work_logprobs; std::vector<float> work_logprobs;
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg;
int64_t t_last; int64_t t_last;
@ -545,10 +543,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: type = %d\n", __func__, model.type); fprintf(stderr, "%s: type = %d\n", __func__, model.type);
wctx.buf_model = new std::vector<uint8_t>(); wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); wctx.buf_model->resize (MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); wctx.buf_memory.resize (MEM_REQ_MEMORY.at(model.type));
wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type)); wctx.buf_memory_cross.resize (MEM_REQ_MEMORY_CROSS.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute.resize (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
} }
@ -638,7 +636,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
wctx.work_logits.reserve(vocab.n_vocab); wctx.work_logits.reserve(vocab.n_vocab);
wctx.work_logprobs.reserve(vocab.n_vocab); wctx.work_logprobs.reserve(vocab.n_vocab);
vocab.probs_id.reserve(n_vocab); wctx.probs_id.reserve(n_vocab);
} }
{ {
@ -1900,7 +1898,7 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token // the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best( static whisper_token_data whisper_sample_best(
whisper_vocab & vocab, whisper_context & ctx,
const float * probs, const float * probs,
bool force_timestamp, bool force_timestamp,
bool is_initial) { bool is_initial) {
@ -1908,9 +1906,11 @@ static whisper_token_data whisper_sample_best(
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
}; };
const auto & vocab = ctx.vocab;
const int n_logits = vocab.n_vocab; const int n_logits = vocab.n_vocab;
auto & probs_id = vocab.probs_id; auto & probs_id = ctx.probs_id;
probs_id.clear(); probs_id.clear();
for (int i = 0; i < n_logits; i++) { for (int i = 0; i < n_logits; i++) {
@ -2461,7 +2461,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2471,7 +2471,7 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2558,7 +2558,9 @@ int whisper_lang_auto_detect(
return -7; return -7;
} }
std::vector<std::pair<float, int>> probs_id; auto & probs_id = ctx->probs_id;
probs_id.clear();
for (const auto & kv : g_lang) { for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first); const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
@ -2566,7 +2568,7 @@ int whisper_lang_auto_detect(
// sort descending // sort descending
{ {
using pair_type = decltype(probs_id)::value_type; using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first; return a.first > b.first;
}); });

Loading…
Cancel
Save