whisper : move probs_id buffer to whisper_context

pull/291/head
Georgi Gerganov 3 years ago
parent 523e0494a6
commit 1163f266d2

@@ -220,10 +220,6 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;

-    // used to avoid memory allocations during sampling
-    // TODO: move to whisper_context in the future
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -387,14 +383,14 @@ struct whisper_model {
     struct ggml_tensor * e_ln_b;

     // decoder.positional_embedding
-    struct ggml_tensor * d_pe; // DD
+    struct ggml_tensor * d_pe;

     // decoder.token_embedding
-    struct ggml_tensor * d_te; // DD
+    struct ggml_tensor * d_te;

     // decoder.ln
-    struct ggml_tensor * d_ln_w; // DD
-    struct ggml_tensor * d_ln_b; // DD
+    struct ggml_tensor * d_ln_w;
+    struct ggml_tensor * d_ln_b;

     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
@@ -451,6 +447,8 @@ struct whisper_context {
     std::vector<float> work_logits;
     std::vector<float> work_logprobs;

+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
     int64_t t_last;
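Note: the member added above is the scratch buffer that previously lived in whisper_vocab. Keeping it on the long-lived context means each sampling call can reuse the vector's capacity instead of allocating a fresh one per token. A minimal standalone sketch of the pattern (hypothetical names, not the whisper.cpp API):

    #include <algorithm>
    #include <utility>
    #include <vector>

    struct sampler_state {
        // long-lived scratch buffer: clear() drops elements but keeps capacity
        std::vector<std::pair<double, int>> probs_id;
    };

    // returns the index of the largest probability; assumes n > 0
    int pick_best(sampler_state & st, const float * probs, int n) {
        st.probs_id.clear(); // no deallocation: capacity from earlier calls is reused
        for (int i = 0; i < n; i++) {
            st.probs_id.emplace_back(probs[i], i);
        }
        return std::max_element(st.probs_id.begin(), st.probs_id.end())->second;
    }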
@@ -545,10 +543,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: type = %d\n", __func__, model.type);

         wctx.buf_model = new std::vector<uint8_t>();
-        wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
-        wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
-        wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
-        wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+        wctx.buf_model->resize      (MEM_REQ_MODEL.at(model.type));
+        wctx.buf_memory.resize      (MEM_REQ_MEMORY.at(model.type));
+        wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
+        wctx.buf_compute.resize     (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
@@ -638,7 +636,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.work_logits.reserve(vocab.n_vocab);
         wctx.work_logprobs.reserve(vocab.n_vocab);

-        vocab.probs_id.reserve(n_vocab);
+        wctx.probs_id.reserve(n_vocab);
     }

     {
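Note: reserving at model-load time sizes the scratch buffer once, so the per-token sampling loop never triggers a reallocation. A hedged illustration of the idea (the vocabulary size below is only an example figure):

    std::vector<std::pair<double, int>> probs_id;
    const int n_vocab = 51865; // example: multilingual Whisper vocabulary size
    probs_id.reserve(n_vocab); // one allocation up front; later emplace_back
                               // calls up to n_vocab entries never reallocate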
@@ -1008,7 +1006,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
     }

-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

         fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
@@ -1900,17 +1898,19 @@ static bool whisper_decode(

 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        whisper_vocab & vocab,
+        whisper_context & ctx,
         const float * probs,
-        bool force_timestamp,
-        bool is_initial) {
+              bool force_timestamp,
+              bool is_initial) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };

+    const auto & vocab = ctx.vocab;
+
     const int n_logits = vocab.n_vocab;

-    auto & probs_id = vocab.probs_id;
+    auto & probs_id = ctx.probs_id;
     probs_id.clear();

     for (int i = 0; i < n_logits; i++) {
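Note: passing the whole context instead of just the vocab lets the sampler reach both the read-only vocabulary and the context-owned scratch buffer, while the local alias keeps the rest of the body unchanged. A reduced sketch of the shape of this refactor (hypothetical types, not the real whisper.cpp structs):

    #include <utility>
    #include <vector>

    struct vocab_t   { int n_vocab; };
    struct context_t {
        vocab_t vocab;
        std::vector<std::pair<double, int>> probs_id; // scratch owned by the context
    };

    void sample(context_t & ctx, const float * probs) {
        const auto & vocab    = ctx.vocab;    // local alias, as in the hunk above
        auto       & probs_id = ctx.probs_id; // mutable scratch on the context
        probs_id.clear();
        for (int i = 0; i < vocab.n_vocab; i++) {
            probs_id.emplace_back(probs[i], i);
        }
    }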
@@ -2461,7 +2461,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
 struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();

-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
+    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -2471,7 +2471,7 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
 struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();

-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
+    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -2558,7 +2558,9 @@ int whisper_lang_auto_detect(
         return -7;
     }

-    std::vector<std::pair<float, int>> probs_id;
+    auto & probs_id = ctx->probs_id;
+    probs_id.clear();
+
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
         probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
@@ -2566,7 +2568,7 @@ int whisper_lang_auto_detect(
     // sort descending
     {
-        using pair_type = decltype(probs_id)::value_type;
+        using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;
         std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
             return a.first > b.first;
         });
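Note: the using-declaration change is required because probs_id is now a reference. decltype on an unparenthesized variable yields its declared type, so decltype(probs_id) is a reference type, and member access such as ::value_type is ill-formed on references; std::remove_reference strips the reference first. A standalone demonstration (hypothetical example, C++11):

    #include <type_traits>
    #include <utility>
    #include <vector>

    int main() {
        std::vector<std::pair<double, int>> storage;
        auto & probs_id = storage;

        // decltype(probs_id) is std::vector<std::pair<double, int>> &,
        // so decltype(probs_id)::value_type would not compile.
        using pair_type = std::remove_reference<decltype(probs_id)>::type::value_type;

        static_assert(std::is_same<pair_type, std::pair<double, int>>::value,
                      "reference stripped, value_type recovered");
        return 0;
    }

Since C++14, std::remove_reference_t<decltype(probs_id)>::value_type expresses the same thing more compactly.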
