From f2d174f530119ee6c361ced202dfad61fe9e9007 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Feb 2023 17:14:28 +0200 Subject: [PATCH] whisper : add support for quantized models --- examples/whisper/whisper.cpp | 172 +++++++++++++++++++++-------------- 1 file changed, 105 insertions(+), 67 deletions(-) diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 3a21581..90b7f7c 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -252,6 +252,8 @@ static const std::map MEM_REQ_SCRATCH3 = { { MODEL_LARGE, 9ull*MB }, }; +// TODO determine this on runtime +#if 0 // F16 static const std::map MEM_REQ_MODEL = { { MODEL_TINY, 74ull*MB }, { MODEL_BASE, 142ull*MB }, @@ -260,6 +262,17 @@ static const std::map MEM_REQ_MODEL = { { MODEL_LARGE, 2952ull*MB }, }; +#else // Q4 + +static const std::map MEM_REQ_MODEL = { + { MODEL_TINY, 1ull*MB }, // TODO + { MODEL_BASE, 1ull*MB }, // TODO + { MODEL_SMALL, 154ull*MB }, + { MODEL_MEDIUM, 470ull*MB }, + { MODEL_LARGE, 940ull*MB }, +}; +#endif + static const std::map MEM_REQ_KV_SELF = { { MODEL_TINY, 3ull*MB }, { MODEL_BASE, 6ull*MB }, @@ -681,7 +694,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) { const ggml_type wtype = cache.k->type; WHISPER_ASSERT(wtype == cache.v->type); - WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype)); struct ggml_init_params params; params.mem_size = cache.buf.size(); @@ -776,12 +789,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.type = e_model::MODEL_LARGE; } - // for the big tensors, we have the option to store the data in 16-bit floats + // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + wctx.wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wctx.wtype = GGML_TYPE_F32; break; + case 1: wctx.wtype = GGML_TYPE_F16; break; + case 2: wctx.wtype = GGML_TYPE_Q4_0; break; + case 3: wctx.wtype = GGML_TYPE_Q4_1; break; + default: + { + fprintf(stderr, "%s: invalid model (bad f16 value %d)\n", __func__, model.hparams.f16); + return false; + } + } const size_t scale = model.hparams.f16 ? 1 : 2; + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); @@ -792,7 +818,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); - fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); + fprintf(stderr, "%s: ftype = %s\n", __func__, ftype_str[model.hparams.f16]); fprintf(stderr, "%s: type = %d\n", __func__, model.type); // print memory requirements @@ -821,7 +847,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); return false; } @@ -831,7 +857,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, GGML_TYPE_F16, model.hparams.n_audio_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); return false; } @@ -963,92 +989,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // encoder { - ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; + ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe; - ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b + ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_1_w + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b - ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b + ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_2_w + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b; } // decoder { - ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; + ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe; - ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; + ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te; - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; + ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b; } // encoder layers { - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b } // decoder layers { - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b // - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b } ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead @@ -1094,10 +1120,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); - model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); + model.e_conv_1_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_mels, n_audio_state); model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); - model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_audio_state, n_audio_state); model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); @@ -1313,9 +1339,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return false; } - const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; - if (nelements*bpe != ggml_nbytes(tensor)) { + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; @@ -1513,14 +1551,14 @@ static bool whisper_encode( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctx0, ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = @@ -1530,7 +1568,7 @@ static bool whisper_encode( Vcur, n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ); struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); @@ -1546,7 +1584,7 @@ static bool whisper_encode( ggml_permute(ctx0, ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); // K * Q @@ -1564,7 +1602,7 @@ static bool whisper_encode( // ggml_permute(ctx0, // ggml_cpy(ctx0, // Vcur, - // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -1576,7 +1614,7 @@ static bool whisper_encode( Vcur, n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); @@ -1634,7 +1672,7 @@ static bool whisper_encode( wctx.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_state, n_ctx)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else wctx.use_buf(ctx0, 0);