diff --git a/whisper.cpp b/whisper.cpp
index a5b402b..52c70dc 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -24,8 +24,9 @@
         } \
     } while (0)
 
-#define USE_FLASH_ATTN
-//#define USE_FLASH_FF
+#define WHISPER_USE_FLASH_ATTN
+//#define WHISPER_USE_FLASH_FF
+#define WHISPER_MAX_DECODERS 16
 
 // available whisper models
 enum e_model {
@@ -416,6 +417,7 @@ struct whisper_model {
 };
 
 struct whisper_decoder {
+    whisper_kv_cache kv_self;
 };
 
 struct whisper_sequence {
@@ -429,20 +431,24 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us = 0;
 
-    // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
-
     ggml_type wtype; // weight type (FP32 or FP16)
 
+    whisper_mel mel;
+
     whisper_model model;
     whisper_vocab vocab;
 
-    whisper_kv_cache kv_self;
     whisper_kv_cache kv_cross;
 
-    whisper_mel mel;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+
+    int selected_decoder_id = 0;
+
+    // memory buffers used by encode / decode contexts
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_compute_layer;
+
+    // decode output
     std::vector<float> probs;
     std::vector<float> logits;
@@ -470,11 +476,14 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
 }
 
-static bool init_kv_cache(
+static bool kv_cache_init(
         const struct whisper_hparams & hparams,
+        const size_t mem_bytes,
         struct whisper_kv_cache & cache,
         ggml_type wtype,
         int n_ctx) {
+    cache.buf.resize(mem_bytes);
+
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size();
     params.mem_buffer = cache.buf.data();
@@ -498,6 +507,41 @@ static bool init_kv_cache(
     return true;
 }
 
+static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+    WHISPER_ASSERT(cache.ctx);
+
+    const int n_elements = ggml_nelements(cache.k);
+    WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+
+    const ggml_type wtype = cache.k->type;
+    WHISPER_ASSERT(wtype == cache.v->type);
+
+    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        cache.ctx = nullptr;
+    }
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -563,6 +607,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         model.type = e_model::MODEL_LARGE;
     }
 
+    // for the big tensors, we have the option to store the data in 16-bit floats
+    // in order to save memory and also to speed up the computation
+    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    const size_t scale = model.hparams.f16 ? 1 : 2;
+
     fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
     fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
     fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
@@ -576,17 +626,47 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
     fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
-    // for the big tensors, we have the option to store the data in 16-bit floats
-    // in order to save memory and also to speed up the computation
-    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    // print memory requirements
+    {
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+                  scale*MEM_REQ_MODEL.at   (model.type) +
+                  scale*MEM_REQ_KV_CROSS.at(model.type) +
+                  scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)) +
+                  scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
 
-    const size_t scale = model.hparams.f16 ? 1 : 2;
+        // this is the memory required by one decoder
+        const size_t mem_required_decoder =
+                  scale*MEM_REQ_KV_SELF.at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+    }
 
     wctx.model.buf = new std::vector<uint8_t>();
     wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-    wctx.kv_self.buf.resize (scale*MEM_REQ_KV_SELF.at(model.type));
-    wctx.kv_cross.buf.resize(scale*MEM_REQ_KV_CROSS.at(model.type));
+    wctx.selected_decoder_id = 0;
+
+    if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        return false;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
+        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    }
+
+    if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        return false;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    }
 
     wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)));
     wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
@@ -673,26 +753,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-        wctx.work_logits.reserve(vocab.n_vocab);
+        wctx.work_logits.reserve  (vocab.n_vocab);
         wctx.work_logprobs.reserve(vocab.n_vocab);
 
         wctx.probs_id.reserve(n_vocab);
     }
 
-    {
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-                  wctx.model.buf->size() +
-                  wctx.kv_self.buf.size() +
-                  wctx.kv_cross.buf.size() +
-                  wctx.buf_compute.size() +
-                  wctx.buf_compute_layer.size();
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
-    }
-
     size_t ctx_size = 0;
 
     const ggml_type wtype = wctx.wtype;
@@ -1016,27 +1084,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // TODO: move to decoder
-    if (!init_kv_cache(model.hparams, wctx.kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-        fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__);
-        return false;
-    }
-
-    {
-        const size_t memory_size = ggml_nbytes(wctx.kv_self.k) + ggml_nbytes(wctx.kv_self.v);
-        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-    }
-
-    if (!init_kv_cache(model.hparams, wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-        fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__);
-        return false;
-    }
-
-    {
-        const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-    }
-
     // load weights
     {
         size_t total_size = 0;
@@ -1266,7 +1313,7 @@ static bool whisper_encode(
 
         // ------
 
-#ifdef USE_FLASH_ATTN
+#ifdef WHISPER_USE_FLASH_ATTN
         struct ggml_tensor * Q =
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
@@ -1377,7 +1424,7 @@ static bool whisper_encode(
                         ggml_repeat(ctxL, layer.mlp_ln_b, cur));
         }
 
-#ifdef USE_FLASH_FF
+#ifdef WHISPER_USE_FLASH_FF
         cur = ggml_flash_ff(ctxL,
                 ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                 layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
@@ -1539,6 +1586,13 @@ static bool whisper_decode(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
+    WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
+
+    auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
+
+    WHISPER_ASSERT(!!kv_self.ctx);
+
     auto & logits_out = wctx.logits;
     auto & probs_out  = wctx.probs;
 
@@ -1630,8 +1684,8 @@ static bool whisper_decode(
 
             // store key and value to memory
             {
-                struct ggml_tensor * k = ggml_view_1d(ctxL, wctx.kv_self.k, N*n_state, (ggml_element_size(wctx.kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctxL, wctx.kv_self.v, N*n_state, (ggml_element_size(wctx.kv_self.v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1649,7 +1703,7 @@ static bool whisper_decode(
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, wctx.kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.k)*n_state),
+                            ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
@@ -1669,7 +1723,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, wctx.kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.v)*n_state),
+                            ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
@@ -2370,12 +2424,14 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.buf) {
             delete ctx->model.buf;
         }
-        if (ctx->kv_self.ctx) {
-            ggml_free(ctx->kv_self.ctx);
-        }
         if (ctx->kv_cross.ctx) {
             ggml_free(ctx->kv_cross.ctx);
         }
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            if (ctx->decoders[i].kv_self.ctx) {
+                ggml_free(ctx->decoders[i].kv_self.ctx);
+            }
+        }
         delete ctx;
     }
 }
@@ -3413,14 +3469,16 @@ int whisper_full_parallel(
 
         ctx_p = *ctx;
 
-        if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_self, ctx_p.wtype, ctx_p.model.hparams.n_text_ctx)) {
-            fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__);
+        if (!kv_cache_reinit(ctx_p.kv_cross)) {
+            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__);
             return false;
         }
 
-        if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_cross, ctx_p.wtype, ctx_p.model.hparams.n_audio_ctx)) {
-            fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__);
-            return false;
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__);
+                return false;
+            }
         }
     }
 
@@ -3485,6 +3543,12 @@ int whisper_full_parallel(
         ctx->t_sample_us += ctxs[i].t_sample_us;
         ctx->t_encode_us += ctxs[i].t_encode_us;
         ctx->t_decode_us += ctxs[i].t_decode_us;
+
+        kv_cache_free(ctxs[i].kv_cross);
+
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            kv_cache_free(ctxs[i].decoders[j].kv_self);
+        }
     }
 
     // average the timings
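
For reference, a minimal sketch of the cache lifecycle this patch introduces, end to end: init at model load, reinit over copied buffers in whisper_full_parallel, free at teardown. It assumes the internals defined in the diff above (whisper_context, whisper_decoder, kv_cache_init/kv_cache_reinit/kv_cache_free); the function name example_decoder_caches and its two mem-size parameters are hypothetical, chosen only for illustration.

// Hypothetical usage sketch (not part of the patch). Relies on the
// whisper.cpp internals defined above.
static bool example_decoder_caches(whisper_context & wctx, size_t mem_self_bytes, size_t mem_cross_bytes) {
    const auto & hparams = wctx.model.hparams;

    // model load: decoder 0 gets a self-attention cache up front; the other
    // WHISPER_MAX_DECODERS - 1 slots stay empty (kv_self.ctx == nullptr)
    wctx.selected_decoder_id = 0;
    if (!kv_cache_init(hparams, mem_self_bytes,  wctx.decoders[0].kv_self, wctx.wtype, hparams.n_text_ctx) ||
        !kv_cache_init(hparams, mem_cross_bytes, wctx.kv_cross,            wctx.wtype, hparams.n_audio_ctx)) {
        return false;
    }

    // whisper_full_parallel: copying the context copies the cache buffers,
    // but the ggml contexts still reference the original's memory, so every
    // initialized cache is re-created over its own copied buffer
    whisper_context ctx_p = wctx;
    if (!kv_cache_reinit(ctx_p.kv_cross)) {
        return false;
    }
    for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
        if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
            return false;
        }
    }

    // teardown of the copy: kv_cache_free null-checks ctx, so it is safe to
    // call on decoder slots that were never initialized
    kv_cache_free(ctx_p.kv_cross);
    for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
        kv_cache_free(ctx_p.decoders[j].kv_self);
    }

    return true;
}

Note that kv_cache_reinit() deliberately keeps the buffer at its existing size (it only asserts the buffer is large enough and rebuilds the ggml bookkeeping), which is what makes copying a whole whisper_context for a parallel worker cheap: no cache reallocation, just fresh tensors over the already-copied memory.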