whisper : reorder scratch buffers in the decoder

pull/431/head
Georgi Gerganov 2 years ago
parent 62205aedb9
commit 01669ee8b8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1705,18 +1705,18 @@ static bool whisper_encode(
}
// cur
{
printf("ne0 = %d\n", cur->ne[0]);
printf("ne1 = %d\n", cur->ne[1]);
for (int i = 0; i < 10; ++i) {
printf("%8.4f ", ((float *)(cur->data))[i]);
}
printf("... ");
for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
printf("%8.4f ", ((float *)(cur->data))[i]);
}
printf("\n");
}
//{
// printf("ne0 = %d\n", cur->ne[0]);
// printf("ne1 = %d\n", cur->ne[1]);
// for (int i = 0; i < 10; ++i) {
// printf("%8.4f ", ((float *)(cur->data))[i]);
// }
// printf("... ");
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
// printf("%8.4f ", ((float *)(cur->data))[i]);
// }
// printf("\n");
//}
// pre-compute cross-attention memory
{
@ -1768,12 +1768,12 @@ static bool whisper_encode(
////////////////////////////////////////////////////////////////////////////
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0);
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1024.0/1024.0,
// wctx.get_buf_max_mem(0)/1024.0/1024.0,
// wctx.get_buf_max_mem(1)/1024.0/1024.0,
// wctx.get_buf_max_mem(2)/1024.0/1024.0,
// wctx.get_buf_max_mem(3)/1024.0/1024.0);
ggml_free(ctx0);
@ -1900,8 +1900,6 @@ static bool whisper_decode(
Vcur),
Vcur);
wctx.use_buf(ctx0, 0);
// store key and value to memory
{
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
@ -1913,7 +1911,7 @@ static bool whisper_decode(
// ------
wctx.use_buf(ctx0, 1);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * Q =
ggml_permute(ctx0,
@ -1929,12 +1927,12 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
wctx.use_buf(ctx0, 0);
wctx.use_buf(ctx0, 1);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
wctx.use_buf(ctx0, 1);
wctx.use_buf(ctx0, 0);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,
@ -1944,11 +1942,11 @@ static bool whisper_decode(
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
wctx.use_buf(ctx0, 0);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
wctx.use_buf(ctx0, 1);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
@ -1957,14 +1955,12 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3);
wctx.use_buf(ctx0, 0);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
wctx.use_buf(ctx0, 1);
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
@ -2200,12 +2196,12 @@ static bool whisper_decode(
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
if (N > 1) {
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0);
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1024.0/1024.0,
// wctx.get_buf_max_mem(0)/1024.0/1024.0,
// wctx.get_buf_max_mem(1)/1024.0/1024.0,
// wctx.get_buf_max_mem(2)/1024.0/1024.0,
// wctx.get_buf_max_mem(3)/1024.0/1024.0);
}
ggml_free(ctx0);

Loading…
Cancel
Save