@@ -99,35 +99,12 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define WHISPER_PRINT_DEBUG(...)
 #endif

-#define WHISPER_USE_SCRATCH
 #define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16

-// TODO: make these part of the whisper_context instead of macros
-#if defined(WHISPER_USE_SCRATCH)
-static int g_buf_last = 0;
-static size_t g_buf_max_size[16] = { 0 };
-
-#define USE_BUF(wctx, gctx, i) \
-    do { \
-        size_t last_size = 0; \
-        if (i == -1) { \
-            last_size = ggml_set_scratch(gctx, { 0, 0, nullptr, }); \
-        } else { \
-            auto & buf = wctx.buf_scratch[i]; \
-            last_size = ggml_set_scratch(gctx, { 0, buf.size(), buf.data(), }); \
-        } \
-        if (g_buf_last >= 0) { \
-            g_buf_max_size[g_buf_last] = std::max(g_buf_max_size[g_buf_last], last_size); \
-        } \
-        g_buf_last = i; \
-    } while (0)
-
-#define GET_BUF_MAX_MEM(i) g_buf_max_size[i]
-#else
-#define USE_BUF(wctx, gctx, i) do {} while (0)
-#define GET_BUF_MAX_MEM(i) 0
-#endif
+#define WHISPER_USE_SCRATCH
+#define WHISPER_MAX_SCRATCH_BUFFERS 16

 // available whisper models
 enum e_model {
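The macro block removed above and the member functions added further down implement the same idea: a handful of pre-allocated scratch buffers that the graph-building code switches between, so that intermediate tensors overwrite each other instead of each claiming fresh space in buf_compute. Below is a minimal standalone sketch of that ping-pong pattern, under the assumption spelled out in the comments: ScratchArena, its sizes, and the allocation amounts are invented for illustration and are not ggml or whisper.cpp API.

// Minimal sketch of the scratch-buffer idea (illustrative only; ScratchArena and the
// sizes below are made up). Two fixed arenas are reused in ping-pong fashion: while
// results accumulate in one arena, the other can be overwritten by the next step.
#include <cstdio>
#include <cstddef>
#include <vector>
#include <algorithm>

struct ScratchArena {
    std::vector<unsigned char> data;
    size_t offs     = 0; // bump-pointer offset into the arena
    size_t max_used = 0; // high-water mark, analogous to buf_max_size[i]

    explicit ScratchArena(size_t size) : data(size) {}

    void * alloc(size_t n) {
        if (offs + n > data.size()) return nullptr; // would overflow the arena
        void * p = data.data() + offs;
        offs += n;
        max_used = std::max(max_used, offs);
        return p;
    }

    void reset() { offs = 0; } // memory is reused, never freed per step
};

int main() {
    ScratchArena buf[2] = { ScratchArena(1u << 20), ScratchArena(1u << 20) };

    // alternate between the two arenas, resetting the one about to be overwritten
    for (int step = 0; step < 8; ++step) {
        ScratchArena & cur = buf[step % 2];
        cur.reset();
        cur.alloc(1000u*(step + 1)); // stand-in for an intermediate tensor
    }

    printf("peak usage: %zu, %zu bytes\n", buf[0].max_used, buf[1].max_used);
    return 0;
}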
@@ -598,8 +575,11 @@ struct whisper_context {
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};

     // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_scratch[4];
     std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+
+    int buf_last = 0;
+    size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };

     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -620,6 +600,36 @@ struct whisper_context {
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx; // 0 - use default

+    void use_buf(struct ggml_context * ctx, int i) {
+#if defined(WHISPER_USE_SCRATCH)
+        size_t last_size = 0;
+
+        if (i == -1) {
+            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+        } else {
+            auto & buf = buf_scratch[i];
+            last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+        }
+
+        if (buf_last >= 0) {
+            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+        }
+
+        buf_last = i;
+#else
+        (void) i;
+#endif
+    }
+
+    size_t get_buf_max_mem(int i) const {
+#if defined(WHISPER_USE_SCRATCH)
+        return buf_max_size[i];
+#else
+        (void) i;
+        return 0;
+#endif
+    }
+
 };

 template<typename T>
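A note on the bookkeeping in use_buf() above: ggml_set_scratch() returns how much of the previously installed scratch buffer had been used, and use_buf() credits that amount to buf_last, the buffer that was active until this call; get_buf_max_mem() then reports the per-buffer peak. The following self-contained mock reproduces just that accounting so it can be checked in isolation. MockContext, Tracker, and set_scratch() here are invented for this sketch and are not ggml or whisper.cpp types.

// Mock of the high-water-mark accounting in use_buf()/get_buf_max_mem().
// Key property demonstrated: the usage reported when a new scratch buffer is
// installed is attributed to the buffer that was active *until now* (buf_last).
#include <cstdio>
#include <cstddef>
#include <algorithm>

struct MockContext {
    size_t used = 0; // bytes consumed from the currently installed scratch buffer
    // mimics ggml_set_scratch(): reports usage of the previous scratch, installs a new one
    size_t set_scratch(size_t /*size*/, void * /*data*/) {
        const size_t last = used;
        used = 0;
        return last;
    }
};

struct Tracker {
    int    buf_last = 0;
    size_t buf_max_size[4] = { 0 };

    void use_buf(MockContext & ctx, int i) {
        const size_t last_size = ctx.set_scratch(0, nullptr); // size/data elided in the mock
        if (buf_last >= 0) {
            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
        }
        buf_last = i;
    }
    size_t get_buf_max_mem(int i) const { return buf_max_size[i]; }
};

int main() {
    MockContext ctx;
    Tracker wctx;

    wctx.use_buf(ctx, 0); ctx.used = 4096; // pretend some ops consumed 4 KB of buffer 0
    wctx.use_buf(ctx, 1); ctx.used = 1024; // switching attributes the 4 KB to buffer 0
    wctx.use_buf(ctx, -1);                 // back to the main buffer; 1 KB goes to buffer 1

    printf("buf 0 peak: %zu, buf 1 peak: %zu\n",
            wctx.get_buf_max_mem(0), wctx.get_buf_max_mem(1));
    return 0;
}

Running this prints "buf 0 peak: 4096, buf 1 peak: 1024", i.e. each switch attributes the consumed bytes to the buffer that was active before the switch, which is exactly what the printf hunks near the end of whisper_encode/whisper_decode report.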
@@ -828,12 +838,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }

-    wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-
     wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
     wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
     wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
     wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
+
+    wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
 }

 // load mel filters
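The scratch sizes used in the hunk above come from per-model lookup tables keyed on the model type. A sketch of the shape such a table presumably has is below; the enumerator names are assumed to mirror the e_model enum referenced earlier in the file, and the byte counts are placeholders, not whisper.cpp's real numbers.

// Shape of a per-model scratch size table (sketch only; values are placeholders).
#include <cstdio>
#include <cstddef>
#include <map>

enum e_model { MODEL_TINY, MODEL_BASE, MODEL_SMALL, MODEL_MEDIUM, MODEL_LARGE };

static const std::map<e_model, size_t> MEM_REQ_SCRATCH0_EXAMPLE = {
    { MODEL_TINY,   12ull*1024*1024 }, // placeholder
    { MODEL_BASE,   15ull*1024*1024 }, // placeholder
    { MODEL_SMALL,  23ull*1024*1024 }, // placeholder
    { MODEL_MEDIUM, 31ull*1024*1024 }, // placeholder
    { MODEL_LARGE,  38ull*1024*1024 }, // placeholder
};

// usage mirrors the resize calls in the hunk above:
//   wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
int main() {
    printf("scratch0 (base): %zu bytes\n", MEM_REQ_SCRATCH0_EXAMPLE.at(MODEL_BASE));
    return 0;
}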
@@ -1366,7 +1376,7 @@ static bool whisper_encode(

     struct ggml_context * ctx0 = ggml_init(params);

-    USE_BUF(wctx, ctx0, 0);
+    wctx.use_buf(ctx0, 0);

     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
@@ -1388,7 +1398,7 @@ static bool whisper_encode(

     // convolution + gelu
     {
-        USE_BUF(wctx, ctx0, 1);
+        wctx.use_buf(ctx0, 1);

         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
         cur = ggml_add(ctx0,
@@ -1399,7 +1409,7 @@ static bool whisper_encode(

         cur = ggml_gelu(ctx0, cur);

-        USE_BUF(wctx, ctx0, 0);
+        wctx.use_buf(ctx0, 0);

         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
         cur = ggml_add(ctx0,
@@ -1411,7 +1421,7 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }

-    USE_BUF(wctx, ctx0, 3);
+    wctx.use_buf(ctx0, 3);

     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1446,7 +1456,7 @@ static bool whisper_encode(

         // norm
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_norm(ctx0, inpL);

@@ -1460,7 +1470,7 @@ static bool whisper_encode(

         // self-attention
         {
-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
@@ -1493,7 +1503,7 @@ static bool whisper_encode(

             // ------

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
@@ -1570,7 +1580,7 @@ static bool whisper_encode(
 #endif
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_cpy(ctx0,
                     KQV_merged,
@@ -1579,20 +1589,20 @@ static bool whisper_encode(

         // projection
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }

-        USE_BUF(wctx, ctx0, 2);
+        wctx.use_buf(ctx0, 2);

         // add the input
         cur = ggml_add(ctx0, cur, inpL);
@@ -1603,11 +1613,11 @@ static bool whisper_encode(
         {
             // norm
             {
-                USE_BUF(wctx, ctx0, 0);
+                wctx.use_buf(ctx0, 0);

                 cur = ggml_norm(ctx0, inpFF);

-                USE_BUF(wctx, ctx0, 1);
+                wctx.use_buf(ctx0, 1);

                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -1618,38 +1628,38 @@ static bool whisper_encode(
             }

 #ifdef WHISPER_USE_FLASH_FF
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_flash_ff(ctx0,
                     ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, N)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // GELU activation
             cur = ggml_gelu(ctx0, cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
@@ -1657,7 +1667,7 @@ static bool whisper_encode(
 #endif
         }

-        USE_BUF(wctx, ctx0, 3);
+        wctx.use_buf(ctx0, 3);

         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -1666,11 +1676,11 @@ static bool whisper_encode(

     // norm
     {
-        USE_BUF(wctx, ctx0, 0);
+        wctx.use_buf(ctx0, 0);

         cur = ggml_norm(ctx0, cur);

-        USE_BUF(wctx, ctx0, 1);
+        wctx.use_buf(ctx0, 1);

         // cur = ln_f_g*cur + ln_f_b
         cur = ggml_add(ctx0,
@@ -1680,7 +1690,7 @@ static bool whisper_encode(
                 ggml_repeat(ctx0, model.e_ln_b, cur));
     }

-    USE_BUF(wctx, ctx0, -1);
+    wctx.use_buf(ctx0, -1);

     // run the computation
     {
@@ -1720,7 +1730,7 @@ static bool whisper_encode(
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];

-        USE_BUF(wctx, ctx0, 0);
+        wctx.use_buf(ctx0, 0);

         struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_k_w,
@@ -1728,7 +1738,7 @@ static bool whisper_encode(

         Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

-        USE_BUF(wctx, ctx0, 1);
+        wctx.use_buf(ctx0, 1);

         struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
@@ -1740,7 +1750,7 @@ static bool whisper_encode(
                     Vcross),
                 Vcross);

-        USE_BUF(wctx, ctx0, -1);
+        wctx.use_buf(ctx0, -1);

         //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
         //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
@@ -1757,8 +1767,12 @@ static bool whisper_encode(

     ////////////////////////////////////////////////////////////////////////////

-    printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0,
-            GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0);
+    printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+            ggml_used_mem(ctx0)/1024.0/1024.0,
+            wctx.get_buf_max_mem(0)/1024.0/1024.0,
+            wctx.get_buf_max_mem(1)/1024.0/1024.0,
+            wctx.get_buf_max_mem(2)/1024.0/1024.0,
+            wctx.get_buf_max_mem(3)/1024.0/1024.0);

     ggml_free(ctx0);

@@ -1825,7 +1839,7 @@ static bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }

-    USE_BUF(wctx, ctx0, 3);
+    wctx.use_buf(ctx0, 3);

     // token encoding + position encoding
     struct ggml_tensor * cur =
@@ -1840,7 +1854,7 @@ static bool whisper_decode(

         // norm
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_norm(ctx0, inpL);

@@ -1854,7 +1868,7 @@ static bool whisper_decode(

         // self-attention
         {
-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
@@ -1885,7 +1899,7 @@ static bool whisper_decode(
                         Vcur),
                     Vcur);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // store key and value to memory
             {
@@ -1898,7 +1912,7 @@ static bool whisper_decode(

             // ------

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -1914,12 +1928,12 @@ static bool whisper_decode(
                         n_state/n_head, n_head, n_past + N),
                 0, 2, 1, 3);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -1929,11 +1943,11 @@ static bool whisper_decode(

             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * V_trans =
                 ggml_permute(ctx0,
@@ -1942,13 +1956,13 @@ static bool whisper_decode(
                         n_state/n_head, n_head, n_past + N),
                 1, 2, 0, 3);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_cpy(ctx0,
                 KQV_merged,
@@ -1957,31 +1971,31 @@ static bool whisper_decode(

         // projection
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }

-        USE_BUF(wctx, ctx0, 2);
+        wctx.use_buf(ctx0, 2);

         // add the input
         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);

         // norm
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -1993,7 +2007,7 @@ static bool whisper_decode(

         // cross-attention
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
@@ -2022,7 +2036,7 @@ static bool whisper_decode(

             // ------

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -2033,7 +2047,7 @@ static bool whisper_decode(

             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2047,15 +2061,15 @@ static bool whisper_decode(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

@@ -2067,20 +2081,20 @@ static bool whisper_decode(

         // projection
         {
-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                     cur);
         }

-        USE_BUF(wctx, ctx0, 2);
+        wctx.use_buf(ctx0, 2);

         // add the input
         cur = ggml_add(ctx0, cur, inpCA);
@@ -2091,11 +2105,11 @@ static bool whisper_decode(
         {
             // norm
             {
-                USE_BUF(wctx, ctx0, 0);
+                wctx.use_buf(ctx0, 0);

                 cur = ggml_norm(ctx0, inpFF);

-                USE_BUF(wctx, ctx0, 1);
+                wctx.use_buf(ctx0, 1);

                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -2105,39 +2119,39 @@ static bool whisper_decode(
                     ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             // GELU activation
             cur = ggml_gelu(ctx0, cur);

-            USE_BUF(wctx, ctx0, 1);
+            wctx.use_buf(ctx0, 1);

             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);

-            USE_BUF(wctx, ctx0, 0);
+            wctx.use_buf(ctx0, 0);

             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
                     cur);
         }

-        USE_BUF(wctx, ctx0, 3);
+        wctx.use_buf(ctx0, 3);

         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -2146,11 +2160,11 @@ static bool whisper_decode(

     // norm
     {
-        USE_BUF(wctx, ctx0, 0);
+        wctx.use_buf(ctx0, 0);

         cur = ggml_norm(ctx0, cur);

-        USE_BUF(wctx, ctx0, 1);
+        wctx.use_buf(ctx0, 1);

         cur = ggml_add(ctx0,
             ggml_mul(ctx0,
@@ -2159,7 +2173,7 @@ static bool whisper_decode(
             ggml_repeat(ctx0, model.d_ln_b, cur));
     }

-    USE_BUF(wctx, ctx0, 0);
+    wctx.use_buf(ctx0, 0);

     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
@@ -2168,7 +2182,7 @@ static bool whisper_decode(

     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

-    USE_BUF(wctx, ctx0, -1);
+    wctx.use_buf(ctx0, -1);

     // run the computation
     {
@@ -2185,8 +2199,12 @@ static bool whisper_decode(
     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);

     if (N > 1) {
-        printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0,
-                GET_BUF_MAX_MEM(0)/1024.0/1024.0, GET_BUF_MAX_MEM(1)/1024.0/1024.0, GET_BUF_MAX_MEM(2)/1024.0/1024.0, GET_BUF_MAX_MEM(3)/1024.0/1024.0);
+        printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+                ggml_used_mem(ctx0)/1024.0/1024.0,
+                wctx.get_buf_max_mem(0)/1024.0/1024.0,
+                wctx.get_buf_max_mem(1)/1024.0/1024.0,
+                wctx.get_buf_max_mem(2)/1024.0/1024.0,
+                wctx.get_buf_max_mem(3)/1024.0/1024.0);
     }

     ggml_free(ctx0);