diff --git a/whisper.cpp b/whisper.cpp index 734fb78..21e559f 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -99,35 +99,12 @@ static void byteswap_tensor(ggml_tensor * tensor) { #define WHISPER_PRINT_DEBUG(...) #endif -#define WHISPER_USE_SCRATCH #define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 -// TODO: make these part of the whisper_context instead of macros -#if defined(WHISPER_USE_SCRATCH) -static int g_buf_last = 0; -static size_t g_buf_max_size[16] = { 0 }; -#define USE_BUF(wctx, gctx, i) \ - do { \ - size_t last_size = 0; \ - if (i == -1) { \ - last_size = ggml_set_scratch(gctx, { 0, 0, nullptr, }); \ - } else { \ - auto & buf = wctx.buf_scratch[i]; \ - last_size = ggml_set_scratch(gctx, { 0, buf.size(), buf.data(), }); \ - } \ - if (g_buf_last >= 0) { \ - g_buf_max_size[g_buf_last] = std::max(g_buf_max_size[g_buf_last], last_size); \ - } \ - g_buf_last = i; \ - } while (0) - -#define GET_BUF_MAX_MEM(i) g_buf_max_size[i] -#else -#define USE_BUF(wctx, gctx, i) do {} while (0) -#define GET_BUF_MAX_MEM(i) 0 -#endif +#define WHISPER_USE_SCRATCH +#define WHISPER_MAX_SCRATCH_BUFFERS 16 // available whisper models enum e_model { @@ -598,8 +575,11 @@ struct whisper_context { whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; // memory buffers used by encode / decode contexts - std::vector buf_scratch[4]; std::vector buf_compute; + std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -620,6 +600,36 @@ struct whisper_context { // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx; // 0 - use default + + void use_buf(struct ggml_context * ctx, int i) { +#if defined(WHISPER_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(WHISPER_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } }; template @@ -828,12 +838,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } + wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); - - wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); } // load mel filters @@ -1366,7 +1376,7 @@ static bool whisper_encode( struct ggml_context * ctx0 = ggml_init(params); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); @@ -1388,7 +1398,7 @@ static bool whisper_encode( // convolution + gelu { - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, @@ -1399,7 +1409,7 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, @@ -1411,7 +1421,7 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } - USE_BUF(wctx, ctx0, 3); + wctx.use_buf(ctx0, 3); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1446,7 +1456,7 @@ static bool whisper_encode( // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1460,7 +1470,7 @@ static bool whisper_encode( // self-attention { - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, @@ -1493,7 +1503,7 @@ static bool whisper_encode( // ------ - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = @@ -1570,7 +1580,7 @@ static bool whisper_encode( #endif struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_cpy(ctx0, KQV_merged, @@ -1579,20 +1589,20 @@ static bool whisper_encode( // projection { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - USE_BUF(wctx, ctx0, 2); + wctx.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpL); @@ -1603,11 +1613,11 @@ static bool whisper_encode( { // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -1618,38 +1628,38 @@ static bool whisper_encode( } #ifdef WHISPER_USE_FLASH_FF - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, N)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), @@ -1657,7 +1667,7 @@ static bool whisper_encode( #endif } - USE_BUF(wctx, ctx0, 3); + wctx.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -1666,11 +1676,11 @@ static bool whisper_encode( // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, @@ -1680,7 +1690,7 @@ static bool whisper_encode( ggml_repeat(ctx0, model.e_ln_b, cur)); } - USE_BUF(wctx, ctx0, -1); + wctx.use_buf(ctx0, -1); // run the computation { @@ -1720,7 +1730,7 @@ static bool whisper_encode( for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, @@ -1728,7 +1738,7 @@ static bool whisper_encode( Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, @@ -1740,7 +1750,7 @@ static bool whisper_encode( Vcross), Vcross); - USE_BUF(wctx, ctx0, -1); + wctx.use_buf(ctx0, -1); //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); @@ -1757,8 +1767,12 @@ static bool whisper_encode( //////////////////////////////////////////////////////////////////////////// - printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, - GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0); + printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + ggml_used_mem(ctx0)/1024.0/1024.0, + wctx.get_buf_max_mem(0)/1024.0/1024.0, + wctx.get_buf_max_mem(1)/1024.0/1024.0, + wctx.get_buf_max_mem(2)/1024.0/1024.0, + wctx.get_buf_max_mem(3)/1024.0/1024.0); ggml_free(ctx0); @@ -1825,7 +1839,7 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } - USE_BUF(wctx, ctx0, 3); + wctx.use_buf(ctx0, 3); // token encoding + position encoding struct ggml_tensor * cur = @@ -1840,7 +1854,7 @@ static bool whisper_decode( // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1854,7 +1868,7 @@ static bool whisper_decode( // self-attention { - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, @@ -1885,7 +1899,7 @@ static bool whisper_decode( Vcur), Vcur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // store key and value to memory { @@ -1898,7 +1912,7 @@ static bool whisper_decode( // ------ - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1914,12 +1928,12 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); //struct ggml_tensor * KQ_scaled = // ggml_scale(ctx0, @@ -1929,11 +1943,11 @@ static bool whisper_decode( struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * V_trans = ggml_permute(ctx0, @@ -1942,13 +1956,13 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_cpy(ctx0, KQV_merged, @@ -1957,31 +1971,31 @@ static bool whisper_decode( // projection { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - USE_BUF(wctx, ctx0, 2); + wctx.use_buf(ctx0, 2); // add the input struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, @@ -1993,7 +2007,7 @@ static bool whisper_decode( // cross-attention { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, @@ -2022,7 +2036,7 @@ static bool whisper_decode( // ------ - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -2033,7 +2047,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2047,15 +2061,15 @@ static bool whisper_decode( // no masking for cross-attention //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -2067,20 +2081,20 @@ static bool whisper_decode( // projection { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } - USE_BUF(wctx, ctx0, 2); + wctx.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpCA); @@ -2091,11 +2105,11 @@ static bool whisper_decode( { // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -2105,39 +2119,39 @@ static bool whisper_decode( ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); } - USE_BUF(wctx, ctx0, 3); + wctx.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -2146,11 +2160,11 @@ static bool whisper_decode( // norm { - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - USE_BUF(wctx, ctx0, 1); + wctx.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2159,7 +2173,7 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } - USE_BUF(wctx, ctx0, 0); + wctx.use_buf(ctx0, 0); // compute logits only for the last token // comment this line to compute logits for all N tokens @@ -2168,7 +2182,7 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - USE_BUF(wctx, ctx0, -1); + wctx.use_buf(ctx0, -1); // run the computation { @@ -2185,8 +2199,12 @@ static bool whisper_decode( memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); if (N > 1) { - printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, - GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0); + printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + ggml_used_mem(ctx0)/1024.0/1024.0, + wctx.get_buf_max_mem(0)/1024.0/1024.0, + wctx.get_buf_max_mem(1)/1024.0/1024.0, + wctx.get_buf_max_mem(2)/1024.0/1024.0, + wctx.get_buf_max_mem(3)/1024.0/1024.0); } ggml_free(ctx0);