From ee58108df8d4f46b08a3497016e66effb288f0a6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 9 Jan 2023 20:42:32 +0200
Subject: [PATCH] whisper : refactor kv cache into separate struct

---
 whisper.cpp | 279 ++++++++++++++++++++++------------------------
 whisper.h   |   8 --
 2 files changed, 116 insertions(+), 171 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 5b7b3fe..a5b402b 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -149,7 +149,7 @@ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
     { MODEL_LARGE, 2952ull*MB },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
     { MODEL_TINY,   3ull*MB },
     { MODEL_BASE,   6ull*MB },
     { MODEL_SMALL, 16ull*MB },
@@ -157,7 +157,7 @@ static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
     { MODEL_LARGE, 71ull*MB },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_MEMORY_CROSS = {
+static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
     { MODEL_TINY,   9ull*MB },
     { MODEL_BASE,  18ull*MB },
     { MODEL_SMALL, 53ull*MB },
@@ -361,6 +361,15 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cache {
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    struct ggml_context * ctx;
+
+    std::vector<uint8_t> buf;
+};
+
 struct whisper_model {
     e_model type = MODEL_UNKNOWN;
 
@@ -395,25 +404,21 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // key + value memory for self attention
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
-
-    // key + value memory for cross attention
-    struct ggml_tensor * memory_cross_k;
-    struct ggml_tensor * memory_cross_v;
-
     // context
     struct ggml_context * ctx;
-    struct ggml_context * ctx_mem;
-    struct ggml_context * ctx_mem_cross;
+
+    // the model memory buffer is read-only and can be shared between processors
+    std::vector<uint8_t> * buf;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
-struct whisper_decoder_data {
+struct whisper_decoder {
+};
+
+struct whisper_sequence {
 };
 
 struct whisper_context {
@@ -424,17 +429,18 @@
     int64_t t_decode_us = 0;
     int64_t t_start_us = 0;
 
-    std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
-    std::vector<uint8_t> buf_memory;
-    std::vector<uint8_t> buf_memory_cross;
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
+    // memory buffers used by encode / decode contexts
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_compute_layer;
 
     ggml_type wtype; // weight type (FP32 or FP16)
 
     whisper_model model;
     whisper_vocab vocab;
 
+    whisper_kv_cache kv_self;
+    whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
     std::vector<float> probs;
@@ -464,6 +470,34 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
 }
 
+static bool init_kv_cache(
+        const struct whisper_hparams & hparams,
+        struct whisper_kv_cache & cache,
+        ggml_type wtype,
+        int n_ctx) {
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    const int n_text_state = hparams.n_text_state;
+    const int n_text_layer = hparams.n_text_layer;
+
+    const int n_mem      = n_text_layer*n_ctx;
+    const int n_elements = n_text_state*n_mem;
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -542,12 +576,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
         fprintf(stderr, "%s: type = %d\n", __func__, model.type);
 
-        wctx.buf_model = new std::vector<uint8_t>();
-        wctx.buf_model->resize       (MEM_REQ_MODEL.at(model.type));
-        wctx.buf_memory.resize       (MEM_REQ_MEMORY.at(model.type));
-        wctx.buf_memory_cross.resize (MEM_REQ_MEMORY_CROSS.at(model.type));
-        wctx.buf_compute.resize      (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-        wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+        // for the big tensors, we have the option to store the data in 16-bit floats
+        // in order to save memory and also to speed up the computation
+        wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+        const size_t scale = model.hparams.f16 ? 1 : 2;
+
+        wctx.model.buf = new std::vector<uint8_t>();
+        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
+
+        wctx.kv_self.buf.resize (scale*MEM_REQ_KV_SELF.at(model.type));
+        wctx.kv_cross.buf.resize(scale*MEM_REQ_KV_CROSS.at(model.type));
+
+        wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+        wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
 
     // load mel filters
@@ -642,23 +684,19 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     {
         // this is the total memory required to run the inference
        const size_t mem_required =
-                 wctx.buf_model->size() +
-                 wctx.buf_memory.size() +
-                 wctx.buf_memory_cross.size() +
+                 wctx.model.buf->size() +
+                 wctx.kv_self.buf.size() +
+                 wctx.kv_cross.buf.size() +
                  wctx.buf_compute.size() +
                  wctx.buf_compute_layer.size();
 
-        fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
     }
 
-    // for the big tensors, we have the option to store the data in 16-bit floats
-    // in order to save memory and also to speed up the computation
-    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    size_t ctx_size = 0;
 
     const ggml_type wtype = wctx.wtype;
 
-    size_t ctx_size = 0;
-
     {
         const auto & hparams = model.hparams;
 
@@ -766,14 +804,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
-        fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
     // create the ggml context
     {
         struct ggml_init_params params;
-        params.mem_size   = wctx.buf_model->size();
-        params.mem_buffer = wctx.buf_model->data();
+        params.mem_size   = wctx.model.buf->size();
+        params.mem_buffer = wctx.model.buf->data();
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -978,69 +1016,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // create the ggml context for the key/value memory (self-attention)
-    {
-        struct ggml_init_params params;
-        params.mem_size   = wctx.buf_memory.size();
-        params.mem_buffer = wctx.buf_memory.data();
-
-        auto & ctx = model.ctx_mem;
-
-        ctx = ggml_init(params);
-        if (!ctx) {
-            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-            return false;
-        }
-
-        {
-            const auto & hparams = model.hparams;
-
-            const int n_text_state = hparams.n_text_state;
-            const int n_text_layer = hparams.n_text_layer;
-            const int n_text_ctx   = hparams.n_text_ctx;
-
-            const int n_mem      = n_text_layer*n_text_ctx;
-            const int n_elements = n_text_state*n_mem;
-
-            model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
-        }
-
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
-        fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    // TODO: move to decoder
+    if (!init_kv_cache(model.hparams, wctx.kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+        fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__);
+        return false;
     }
 
-    // create the ggml context for the key/value memory (cross-attention)
     {
-        struct ggml_init_params params;
-        params.mem_size   = wctx.buf_memory_cross.size();
-        params.mem_buffer = wctx.buf_memory_cross.data();
-
-        auto & ctx = model.ctx_mem_cross;
-
-        ctx = ggml_init(params);
-        if (!ctx) {
-            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-            return false;
-        }
-
-        {
-            const auto & hparams = model.hparams;
-
-            const int n_text_state = hparams.n_text_state;
-            const int n_text_layer = hparams.n_text_layer;
-            const int n_audio_ctx  = hparams.n_audio_ctx;
-
-            const int n_mem      = n_text_layer*n_audio_ctx;
-            const int n_elements = n_text_state*n_mem;
-
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
-        }
+        const size_t memory_size = ggml_nbytes(wctx.kv_self.k) + ggml_nbytes(wctx.kv_self.v);
+        fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    }
 
-        const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+    if (!init_kv_cache(model.hparams, wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+        fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__);
+        return false;
+    }
 
-        fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
+    {
+        const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
 
     // load weights
@@ -1504,10 +1498,10 @@ static bool whisper_encode(
                         Vcross),
                     Vcross);
 
-        //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-        //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-        struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
-        struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+        //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+        //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+        struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
+        struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
 
         ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
         ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1636,8 +1630,8 @@ static bool whisper_decode(
 
         // store key and value to memory
         {
-            struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
-            struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+            struct ggml_tensor * k = ggml_view_1d(ctxL, wctx.kv_self.k, N*n_state, (ggml_element_size(wctx.kv_self.k)*n_state)*(il*n_ctx + n_past));
+            struct ggml_tensor * v = ggml_view_1d(ctxL, wctx.kv_self.v, N*n_state, (ggml_element_size(wctx.kv_self.v)*n_state)*(il*n_ctx + n_past));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1655,7 +1649,7 @@ static bool whisper_decode(
         struct ggml_tensor * K =
             ggml_permute(ctxL,
                     ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.k)*n_state),
                         n_state/n_head, n_head, n_past + N),
                     0, 2, 1, 3);
 
@@ -1675,7 +1669,7 @@ static bool whisper_decode(
         struct ggml_tensor * V_trans =
             ggml_permute(ctxL,
                     ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.v)*n_state),
                         n_state/n_head, n_head, n_past + N),
                     1, 2, 0, 3);
 
@@ -1730,12 +1724,12 @@ static bool whisper_decode(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
             // ------
@@ -2373,14 +2367,14 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.ctx_mem) {
-            ggml_free(ctx->model.ctx_mem);
+        if (ctx->model.buf) {
+            delete ctx->model.buf;
         }
-        if (ctx->model.ctx_mem_cross) {
-            ggml_free(ctx->model.ctx_mem_cross);
+        if (ctx->kv_self.ctx) {
+            ggml_free(ctx->kv_self.ctx);
         }
-        if (ctx->buf_model) {
-            delete ctx->buf_model;
+        if (ctx->kv_cross.ctx) {
+            ggml_free(ctx->kv_cross.ctx);
         }
         delete ctx;
     }
@@ -2458,7 +2452,8 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
+// TODO: remove
+static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
@@ -2468,7 +2463,8 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     return res;
 }
 
-struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
+// TODO: remove
+static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
@@ -3413,61 +3409,18 @@ int whisper_full_parallel(
     std::vector<whisper_context> ctxs(n_processors - 1);
 
     for (int i = 0; i < n_processors - 1; ++i) {
-        ctxs[i] = *ctx;
-
-        auto & model = ctxs[i].model;
-
-        // separate key + value memory for each processor (self-attention)
-        {
-            struct ggml_init_params params;
-            params.mem_size   = ctxs[i].buf_memory.size();
-            params.mem_buffer = ctxs[i].buf_memory.data();
-
-            auto & mctx = model.ctx_mem;
-            mctx = ggml_init(params);
-            if (!mctx) {
-                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-                return false;
-            }
-
-            {
-                const auto & hparams = model.hparams;
-
-                const int n_text_state = hparams.n_text_state;
-                const int n_text_layer = hparams.n_text_layer;
-                const int n_text_ctx   = hparams.n_text_ctx;
+        auto & ctx_p = ctxs[i];
 
-                const int n_mem      = n_text_layer*n_text_ctx;
-                const int n_elements = n_text_state*n_mem;
+        ctx_p = *ctx;
 
-                model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-                model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            }
+        if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_self, ctx_p.wtype, ctx_p.model.hparams.n_text_ctx)) {
+            fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__);
+            return false;
         }
 
-        // separate key + value memory for each processor (cross-attention)
-        {
-            struct ggml_init_params params;
-            params.mem_size   = ctxs[i].buf_memory_cross.size();
-            params.mem_buffer = ctxs[i].buf_memory_cross.data();
-
-            auto & mctx = model.ctx_mem_cross;
-            mctx = ggml_init(params);
-            if (!mctx) {
-                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-                return false;
-            }
-
-            const auto & hparams = model.hparams;
-
-            const int n_text_state = hparams.n_text_state;
-            const int n_text_layer = hparams.n_text_layer;
-            const int n_audio_ctx  = hparams.n_audio_ctx;
-
-            const int n_mem      = n_text_layer*n_audio_ctx;
-            const int n_elements = n_text_state*n_mem;
-
-            model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+        if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_cross, ctx_p.wtype, ctx_p.model.hparams.n_audio_ctx)) {
+            fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__);
+            return false;
+        }
     }
 
diff --git a/whisper.h b/whisper.h
index ccc1ff9..2983765 100644
--- a/whisper.h
+++ b/whisper.h
@@ -143,14 +143,6 @@ extern "C" {
             int   n_past,
             int   n_threads);
 
-    // Token sampling methods.
-    // These are provided for convenience and can be used after each call to whisper_decode().
-    // You can also implement your own sampling method using the whiper_get_logits() or whisper_get_probs() functions.
-    // whisper_sample_best() returns the token with the highest probability
-    // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
-    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
-
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
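
What the patch does, in short: the four memory_* tensors and two ctx_mem* contexts that used to live inside whisper_model become two instances of a small whisper_kv_cache struct on whisper_context (kv_self, sized by n_text_ctx, and kv_cross, sized by n_audio_ctx), both created by the shared init_kv_cache() helper. whisper_full_parallel() then needs only two helper calls per extra processor instead of two hand-rolled allocation blocks, and whisper_free() tears each cache down with a single ggml_free(cache.ctx). The MEM_REQ_KV_SELF / MEM_REQ_KV_CROSS tables appear to budget roughly 2 * n_text_layer * n_ctx * n_text_state elements at 2 bytes each (K plus V in F16) plus ggml object overhead; for the base model that is about 2 * 6 * 448 * 512 * 2 bytes, i.e. around 5.3 MB against the 6 MB table entry.

The sketch below reproduces the same init pattern as a standalone program, for readers who want to poke at it outside whisper.cpp. It is illustrative only: the kv_cache_demo / kv_cache_demo_init names, the 1 MB slack for ggml's object headers, and the hard-coded base-model dimensions are assumptions made here, not code from the patch, and whisper_kv_cache / init_kv_cache() themselves stay internal to whisper.cpp rather than becoming part of the public whisper.h API.

    // kv_cache_demo.cpp -- illustrative sketch; mirrors the init_kv_cache() pattern from the patch
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include "ggml.h"

    struct kv_cache_demo {
        struct ggml_tensor  * k   = nullptr;
        struct ggml_tensor  * v   = nullptr;
        struct ggml_context * ctx = nullptr;
        std::vector<uint8_t>  buf; // backing memory handed to ggml_init()
    };

    // one 1-D K tensor and one 1-D V tensor, covering n_layer layers of n_ctx positions
    static bool kv_cache_demo_init(kv_cache_demo & cache, ggml_type wtype,
                                   int n_state, int n_layer, int n_ctx) {
        const int n_elements = n_state*n_layer*n_ctx;

        // tensor data for K and V plus some slack for ggml's object headers (assumed 1 MB here)
        cache.buf.resize(2ull*n_elements*ggml_type_size(wtype) + 1024*1024);

        struct ggml_init_params params = {
            /*.mem_size   =*/ cache.buf.size(),
            /*.mem_buffer =*/ cache.buf.data(),
        };

        cache.ctx = ggml_init(params);
        if (!cache.ctx) {
            fprintf(stderr, "failed to allocate kv cache context\n");
            return false;
        }

        cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
        cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);

        return cache.k != nullptr && cache.v != nullptr;
    }

    int main() {
        kv_cache_demo self_cache;

        // dimensions of the "base" model: n_text_state = 512, n_text_layer = 6, n_text_ctx = 448
        if (!kv_cache_demo_init(self_cache, GGML_TYPE_F16, 512, 6, 448)) {
            return 1;
        }

        fprintf(stderr, "kv self size = %.2f MB\n",
                (ggml_nbytes(self_cache.k) + ggml_nbytes(self_cache.v))/1024.0/1024.0);

        ggml_free(self_cache.ctx);
        return 0;
    }

Because each cache owns its own ggml_context backed by its own byte buffer, copies of whisper_context made for parallel processing can re-run the same two init calls and get fully independent K/V storage while still sharing the read-only model.buf, which is exactly how the new whisper_full_parallel() loop uses the helper.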