From ad2a4ffa03c1dc9eac859de1c2dad720954720a0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 5 Jan 2023 22:56:25 +0200
Subject: [PATCH] whisper : do not use F16 tensors when in F32 mode (#369)

---
 whisper.cpp | 40 +++++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index a61eded..e6d050a 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -412,6 +412,8 @@ struct whisper_context {
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
+    ggml_type wtype; // weight type (FP32 or FP16)
+
     whisper_model model;
     whisper_vocab vocab;
 
@@ -629,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // for the big tensors, we have the option to store the data in 16-bit floats
     // in order to save memory and also to speed up the computation
-    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    const ggml_type wtype = wctx.wtype;
 
     size_t ctx_size = 0;
 
@@ -650,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // encoder
         {
-            // TODO: F16 .. maybe not?
             ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
 
             ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype);         // e_conv_1_w
@@ -665,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // decoder
        {
-            // TODO: F16 .. maybe not?
             ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
 
             ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@@ -982,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
         // key/value memory for the cross-attention layer
@@ -993,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
         const size_t memory_size =
@@ -1240,14 +1242,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
@@ -1257,7 +1259,7 @@ static bool whisper_encode(
                             Vcur,
                             n_state/n_head, n_head, n_ctx),
                         1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1273,7 +1275,7 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@@ -1291,7 +1293,7 @@ static bool whisper_encode(
             //    ggml_permute(ctxL,
             //            ggml_cpy(ctxL,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+            //                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
             //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1303,7 +1305,7 @@ static bool whisper_encode(
                             Vcur,
                             n_state/n_head, n_head, n_ctx),
                         0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1348,7 +1350,7 @@ static bool whisper_encode(
 
 #ifdef USE_FLASH_FF
         cur = ggml_flash_ff(ctxL,
-                ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+                ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                 layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
         // fully connected
@@ -3156,7 +3158,7 @@ int whisper_full_parallel(
 
         // separate key + value memory for each processor
         {
-            auto & ctx = model.ctx_mem;
+            auto & mctx = model.ctx_mem;
 
             const auto & hparams = model.hparams;
 
@@ -3169,8 +3171,8 @@ int whisper_full_parallel(
                 const int n_mem      = n_text_layer*n_text_ctx;
                 const int n_elements = n_text_state*n_mem;
 
-                model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-                model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+                model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
 
             // key/value memory for the cross-attention layer
@@ -3180,8 +3182,8 @@ int whisper_full_parallel(
                 const int n_mem      = n_text_layer*n_audio_ctx;
                 const int n_elements = n_text_state*n_mem;
 
-                model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-                model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+                model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
         }
     }
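Note (not part of the patch): the change derives the tensor element type once
from model.hparams.f16, stores it as wctx.wtype, and reuses it for the KV
memory and the intermediate copies in the encoder, so an F32 model no longer
allocates F16 tensors. Below is a minimal standalone C++ sketch of the size
effect on the self-attention KV memory; the hyperparameter values are
illustrative assumptions (roughly a "base"-sized decoder), not values taken
from the patch.

// kv_size_sketch.cpp - illustrative only, not part of whisper.cpp
#include <cstdio>
#include <cstddef>

enum class wtype_t { F32, F16 };

// element size in bytes, analogous to ggml_type_size(GGML_TYPE_F32/F16)
static std::size_t type_size(wtype_t t) { return t == wtype_t::F16 ? 2 : 4; }

int main() {
    const bool f16 = false; // false corresponds to a model converted to F32

    // same selection as: wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
    const wtype_t wtype = f16 ? wtype_t::F16 : wtype_t::F32;

    // hypothetical text-decoder hyperparameters (illustrative values)
    const int n_text_state = 512;
    const int n_text_layer = 6;
    const int n_text_ctx   = 448;

    // same element count as model.memory_k / model.memory_v in the patch
    const std::size_t n_mem      = (std::size_t) n_text_layer*n_text_ctx;
    const std::size_t n_elements = (std::size_t) n_text_state*n_mem;

    // two tensors (K and V), each with n_elements entries of the chosen type
    const std::size_t memory_size = 2*n_elements*type_size(wtype);

    std::printf("wtype = %s, self-attention KV memory = %.2f MB\n",
            wtype == wtype_t::F16 ? "F16" : "F32",
            memory_size/1024.0/1024.0);

    return 0;
}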