whisper : use whisper_context functions instead of macros

pull/431/head
Georgi Gerganov 2 years ago
parent 79148a2166
commit 42d7dee4c2
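
This commit replaces the file-scope USE_BUF / GET_BUF_MAX_MEM macros and their globals (g_buf_last, g_buf_max_size) with per-context state in whisper_context, exposed through the use_buf() and get_buf_max_mem() member functions. Below is a condensed sketch of the resulting call pattern while building a ggml graph, mirroring what whisper_encode and whisper_decode do in the diff; the tensor names and the size 64 are illustrative only, not part of the change:

    // assumes an existing whisper_context wctx and ggml_context ctx0,
    // as inside whisper_encode()/whisper_decode()
    wctx.use_buf(ctx0, 0);    // data for new tensors now comes from scratch buffer 0
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 64);

    wctx.use_buf(ctx0, 1);    // ping-pong: results of the next ops land in scratch buffer 1
    struct ggml_tensor * b = ggml_gelu(ctx0, a);

    wctx.use_buf(ctx0, -1);   // back to the main compute buffer before evaluating the graph

    // per-buffer high-water marks recorded by use_buf() as it switches buffers
    const size_t peak0 = wctx.get_buf_max_mem(0);

Keeping the counters inside whisper_context means separate contexts no longer share the same globals, and the maxima reported at the end of encode/decode refer to the context that actually did the work.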

@@ -99,35 +99,12 @@ static void byteswap_tensor(ggml_tensor * tensor) {
#define WHISPER_PRINT_DEBUG(...)
#endif
#define WHISPER_USE_SCRATCH
#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16
// TODO: make these part of the whisper_context instead of macros
#if defined(WHISPER_USE_SCRATCH)
static int g_buf_last = 0;
static size_t g_buf_max_size[16] = { 0 };
#define USE_BUF(wctx, gctx, i) \
do { \
size_t last_size = 0; \
if (i == -1) { \
last_size = ggml_set_scratch(gctx, { 0, 0, nullptr, }); \
} else { \
auto & buf = wctx.buf_scratch[i]; \
last_size = ggml_set_scratch(gctx, { 0, buf.size(), buf.data(), }); \
} \
if (g_buf_last >= 0) { \
g_buf_max_size[g_buf_last] = std::max(g_buf_max_size[g_buf_last], last_size); \
} \
g_buf_last = i; \
} while (0)
#define GET_BUF_MAX_MEM(i) g_buf_max_size[i]
#else
#define USE_BUF(wctx, gctx, i) do {} while (0)
#define GET_BUF_MAX_MEM(i) 0
#endif
#define WHISPER_USE_SCRATCH
#define WHISPER_MAX_SCRATCH_BUFFERS 16
// available whisper models
enum e_model {
@@ -598,8 +575,11 @@ struct whisper_context {
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_scratch[4];
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
int buf_last = 0;
size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@@ -620,6 +600,36 @@ struct whisper_context {
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0;
if (i == -1) {
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
} else {
auto & buf = buf_scratch[i];
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
}
if (buf_last >= 0) {
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
}
buf_last = i;
#else
(void) i;
#endif
}
size_t get_buf_max_mem(int i) const {
#if defined(WHISPER_USE_SCRATCH)
return buf_max_size[i];
#else
(void) i;
return 0;
#endif
}
};
template<typename T>
@@ -828,12 +838,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
}
// load mel filters
@@ -1366,7 +1376,7 @@ static bool whisper_encode(
struct ggml_context * ctx0 = ggml_init(params);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
assert(mel->type == GGML_TYPE_F32);
@@ -1388,7 +1398,7 @@ static bool whisper_encode(
// convolution + gelu
{
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
cur = ggml_add(ctx0,
@@ -1399,7 +1409,7 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
cur = ggml_add(ctx0,
@@ -1411,7 +1421,7 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur);
}
USE_BUF(wctx, ctx0, 3);
wctx.use_buf(ctx0, 3);
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1446,7 +1456,7 @@ static bool whisper_encode(
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpL);
@@ -1460,7 +1470,7 @@ static bool whisper_encode(
// self-attention
{
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
@@ -1493,7 +1503,7 @@ static bool whisper_encode(
// ------
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
@@ -1570,7 +1580,7 @@ static bool whisper_encode(
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_cpy(ctx0,
KQV_merged,
@@ -1579,20 +1589,20 @@ static bool whisper_encode(
// projection
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
USE_BUF(wctx, ctx0, 2);
wctx.use_buf(ctx0, 2);
// add the input
cur = ggml_add(ctx0, cur, inpL);
@@ -1603,11 +1613,11 @@ static bool whisper_encode(
{
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpFF);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
@@ -1618,38 +1628,38 @@ static bool whisper_encode(
}
#ifdef WHISPER_USE_FLASH_FF
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_flash_ff(ctx0,
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// GELU activation
cur = ggml_gelu(ctx0, cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
@@ -1657,7 +1667,7 @@ static bool whisper_encode(
#endif
}
USE_BUF(wctx, ctx0, 3);
wctx.use_buf(ctx0, 3);
inpL = ggml_add(ctx0, cur, inpFF);
}
@@ -1666,11 +1676,11 @@ static bool whisper_encode(
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0,
@@ -1680,7 +1690,7 @@ static bool whisper_encode(
ggml_repeat(ctx0, model.e_ln_b, cur));
}
USE_BUF(wctx, ctx0, -1);
wctx.use_buf(ctx0, -1);
// run the computation
{
@@ -1720,7 +1730,7 @@ static bool whisper_encode(
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
@@ -1728,7 +1738,7 @@ static bool whisper_encode(
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
@@ -1740,7 +1750,7 @@ static bool whisper_encode(
Vcross),
Vcross);
USE_BUF(wctx, ctx0, -1);
wctx.use_buf(ctx0, -1);
//struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
@@ -1757,8 +1767,12 @@ static bool whisper_encode(
////////////////////////////////////////////////////////////////////////////
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0,
GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0);
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0);
ggml_free(ctx0);
@@ -1825,7 +1839,7 @@ static bool whisper_decode(
((int32_t *) position->data)[i] = n_past + i;
}
USE_BUF(wctx, ctx0, 3);
wctx.use_buf(ctx0, 3);
// token encoding + position encoding
struct ggml_tensor * cur =
@@ -1840,7 +1854,7 @@ static bool whisper_decode(
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpL);
@@ -1854,7 +1868,7 @@ static bool whisper_decode(
// self-attention
{
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
@@ -1885,7 +1899,7 @@ static bool whisper_decode(
Vcur),
Vcur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// store key and value to memory
{
@@ -1898,7 +1912,7 @@ static bool whisper_decode(
// ------
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * Q =
ggml_permute(ctx0,
@@ -1914,12 +1928,12 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,
@@ -1929,11 +1943,11 @@ static bool whisper_decode(
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
@@ -1942,13 +1956,13 @@ static bool whisper_decode(
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_cpy(ctx0,
KQV_merged,
@@ -1957,31 +1971,31 @@ static bool whisper_decode(
// projection
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
USE_BUF(wctx, ctx0, 2);
wctx.use_buf(ctx0, 2);
// add the input
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
@@ -1993,7 +2007,7 @@ static bool whisper_decode(
// cross-attention
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.cross_attn_q_w,
@@ -2022,7 +2036,7 @@ static bool whisper_decode(
// ------
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * Q =
ggml_permute(ctx0,
@@ -2033,7 +2047,7 @@ static bool whisper_decode(
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2047,15 +2061,15 @@ static bool whisper_decode(
// no masking for cross-attention
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -2067,20 +2081,20 @@ static bool whisper_decode(
// projection
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.cross_attn_ln_1_w,
cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
cur);
}
USE_BUF(wctx, ctx0, 2);
wctx.use_buf(ctx0, 2);
// add the input
cur = ggml_add(ctx0, cur, inpCA);
@@ -2091,11 +2105,11 @@ static bool whisper_decode(
{
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpFF);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
@@ -2105,39 +2119,39 @@ static bool whisper_decode(
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
}
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// GELU activation
cur = ggml_gelu(ctx0, cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
}
USE_BUF(wctx, ctx0, 3);
wctx.use_buf(ctx0, 3);
inpL = ggml_add(ctx0, cur, inpFF);
}
@@ -2146,11 +2160,11 @@ static bool whisper_decode(
// norm
{
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, cur);
USE_BUF(wctx, ctx0, 1);
wctx.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@@ -2159,7 +2173,7 @@ static bool whisper_decode(
ggml_repeat(ctx0, model.d_ln_b, cur));
}
USE_BUF(wctx, ctx0, 0);
wctx.use_buf(ctx0, 0);
// compute logits only for the last token
// comment this line to compute logits for all N tokens
@@ -2168,7 +2182,7 @@ static bool whisper_decode(
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
USE_BUF(wctx, ctx0, -1);
wctx.use_buf(ctx0, -1);
// run the computation
{
@@ -2185,8 +2199,12 @@ static bool whisper_decode(
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
if (N > 1) {
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0,
GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0);
printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0,
wctx.get_buf_max_mem(0)/1024.0/1024.0,
wctx.get_buf_max_mem(1)/1024.0/1024.0,
wctx.get_buf_max_mem(2)/1024.0/1024.0,
wctx.get_buf_max_mem(3)/1024.0/1024.0);
}
ggml_free(ctx0);
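
The maxima printed above come from the per-buffer tracking in use_buf(): it records the value returned by ggml_set_scratch() as the peak usage of the buffer it is switching away from, and get_buf_max_mem() reads those peaks back. A minimal sketch of the kind of check this enables, internal to whisper.cpp since it touches whisper_context members directly; the loop bound 4 matches the four buffers sized from the MEM_REQ_SCRATCH* tables in whisper_model_load():

    // illustrative only: after an encode/decode pass, compare observed
    // scratch peaks against the sizes reserved for them at load time
    for (int i = 0; i < 4; ++i) {   // whisper_model_load() sizes buffers 0..3
        const double peak_mb     = wctx.get_buf_max_mem(i)/1024.0/1024.0;
        const double reserved_mb = wctx.buf_scratch[i].size()/1024.0/1024.0;
        fprintf(stderr, "scratch[%d]: peak %.2f MB of %.2f MB reserved\n", i, peak_mb, reserved_mb);
    }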
