|
|
|
@ -1705,18 +1705,18 @@ static bool whisper_encode(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// cur
|
|
|
|
|
{
|
|
|
|
|
printf("ne0 = %d\n", cur->ne[0]);
|
|
|
|
|
printf("ne1 = %d\n", cur->ne[1]);
|
|
|
|
|
for (int i = 0; i < 10; ++i) {
|
|
|
|
|
printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
|
|
|
}
|
|
|
|
|
printf("... ");
|
|
|
|
|
for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
|
|
|
|
printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
|
|
|
}
|
|
|
|
|
printf("\n");
|
|
|
|
|
}
|
|
|
|
|
//{
|
|
|
|
|
// printf("ne0 = %d\n", cur->ne[0]);
|
|
|
|
|
// printf("ne1 = %d\n", cur->ne[1]);
|
|
|
|
|
// for (int i = 0; i < 10; ++i) {
|
|
|
|
|
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
|
|
|
// }
|
|
|
|
|
// printf("... ");
|
|
|
|
|
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
|
|
|
|
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
|
|
|
// }
|
|
|
|
|
// printf("\n");
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
// pre-compute cross-attention memory
|
|
|
|
|
{
|
|
|
|
@ -1768,12 +1768,12 @@ static bool whisper_encode(
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
|
|
|
ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(0)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(1)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(2)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(3)/1024.0/1024.0);
|
|
|
|
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
|
|
|
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(0)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(1)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(2)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(3)/1024.0/1024.0);
|
|
|
|
|
|
|
|
|
|
ggml_free(ctx0);
|
|
|
|
|
|
|
|
|
@ -1900,8 +1900,6 @@ static bool whisper_decode(
|
|
|
|
|
Vcur),
|
|
|
|
|
Vcur);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
|
|
|
|
|
// store key and value to memory
|
|
|
|
|
{
|
|
|
|
|
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
|
|
|
|
@ -1913,7 +1911,7 @@ static bool whisper_decode(
|
|
|
|
|
|
|
|
|
|
// ------
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * Q =
|
|
|
|
|
ggml_permute(ctx0,
|
|
|
|
@ -1929,12 +1927,12 @@ static bool whisper_decode(
|
|
|
|
|
n_state/n_head, n_head, n_past + N),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
|
|
|
|
|
// K * Q
|
|
|
|
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
|
|
|
|
|
//struct ggml_tensor * KQ_scaled =
|
|
|
|
|
// ggml_scale(ctx0,
|
|
|
|
@ -1944,11 +1942,11 @@ static bool whisper_decode(
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * V_trans =
|
|
|
|
|
ggml_permute(ctx0,
|
|
|
|
@ -1957,14 +1955,12 @@ static bool whisper_decode(
|
|
|
|
|
n_state/n_head, n_head, n_past + N),
|
|
|
|
|
1, 2, 0, 3);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
wctx.use_buf(ctx0, 1);
|
|
|
|
|
|
|
|
|
|
cur = ggml_cpy(ctx0,
|
|
|
|
|
KQV_merged,
|
|
|
|
|
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
|
|
|
|
@ -2200,12 +2196,12 @@ static bool whisper_decode(
|
|
|
|
|
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
|
|
|
|
|
|
|
|
|
if (N > 1) {
|
|
|
|
|
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
|
|
|
ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(0)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(1)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(2)/1024.0/1024.0,
|
|
|
|
|
wctx.get_buf_max_mem(3)/1024.0/1024.0);
|
|
|
|
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
|
|
|
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(0)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(1)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(2)/1024.0/1024.0,
|
|
|
|
|
// wctx.get_buf_max_mem(3)/1024.0/1024.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_free(ctx0);
|
|
|
|
|