whisper : move self-attention kv cache to whisper_decoder

Georgi Gerganov
parent ee58108df8
commit 628843c60a

@@ -24,8 +24,9 @@
     } \
 } while (0)
 
-#define USE_FLASH_ATTN
-//#define USE_FLASH_FF
+#define WHISPER_USE_FLASH_ATTN
+//#define WHISPER_USE_FLASH_FF
+#define WHISPER_MAX_DECODERS 16
 
 // available whisper models
 enum e_model {
@@ -416,6 +417,7 @@ struct whisper_model {
 };
 
 struct whisper_decoder {
+    whisper_kv_cache kv_self;
 };
 
 struct whisper_sequence {
@@ -429,20 +431,24 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us = 0;
 
-    // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
-
     ggml_type wtype; // weight type (FP32 or FP16)
 
+    whisper_mel mel;
+
     whisper_model model;
     whisper_vocab vocab;
 
-    whisper_kv_cache kv_self;
     whisper_kv_cache kv_cross;
 
-    whisper_mel mel;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+
+    int selected_decoder_id = 0;
+
+    // memory buffers used by encode / decode contexts
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_compute_layer;
 
+    // decode output
     std::vector<float> probs;
     std::vector<float> logits;
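
The net effect of this hunk is easier to see as a standalone sketch: each decoder now owns the self-attention cache that records its own token history, while the cross-attention cache stays on the context because every decoder attends to the same encoder output. The types below are simplified stand-ins, not the real whisper.cpp definitions.

    #include <cstdint>
    #include <vector>

    struct kv_cache_sketch {
        std::vector<uint8_t> buf;      // backing memory for the k/v tensors
    };

    struct decoder_sketch {
        kv_cache_sketch kv_self;       // per-decoder self-attention history
    };

    struct context_sketch {
        kv_cache_sketch kv_cross;      // shared: encoder keys/values
        decoder_sketch  decoders[16];  // one slot per WHISPER_MAX_DECODERS
        int selected_decoder_id = 0;   // which cache whisper_decode uses
    };
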
@@ -470,11 +476,14 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
 }
 
-static bool init_kv_cache(
+static bool kv_cache_init(
         const struct whisper_hparams & hparams,
+        const size_t mem_bytes,
         struct whisper_kv_cache & cache,
         ggml_type wtype,
         int n_ctx) {
+    cache.buf.resize(mem_bytes);
+
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size();
     params.mem_buffer = cache.buf.data();
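
Folding the buffer size into the renamed kv_cache_init() lets the function own the whole allocation: callers now pass the per-model byte budget instead of pre-sizing cache.buf themselves, which is why the two resize calls disappear from whisper_model_load further down.
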
@@ -498,6 +507,41 @@ static bool init_kv_cache(
     return true;
 }
 
+static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+    WHISPER_ASSERT(cache.ctx);
+
+    const int n_elements = ggml_nelements(cache.k);
+    WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+
+    const ggml_type wtype = cache.k->type;
+    WHISPER_ASSERT(wtype == cache.v->type);
+
+    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        cache.ctx = nullptr;
+    }
+}
+
 // load the model from a ggml file
 //
 // file format:
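
Why kv_cache_reinit() is needed at all is not obvious from its body alone. When whisper_full_parallel (below) copies a whisper_context by value, each cache's std::vector buffer is deep-copied, but cache.ctx, cache.k, and cache.v still point into the original allocation. A minimal standalone illustration of that hazard, with a raw pointer standing in for the ggml tensors:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct cache_sketch {
        std::vector<uint8_t> buf;  // deep-copied on assignment
        uint8_t * k = nullptr;     // stand-in for a tensor allocated inside buf
    };

    int main() {
        cache_sketch a;
        a.buf.resize(64);
        a.k = a.buf.data();        // "tensor" lives inside a.buf

        cache_sketch b = a;        // b.buf is a fresh copy, but b.k still points into a.buf
        const bool inside = b.k >= b.buf.data() && b.k < b.buf.data() + b.buf.size();
        std::printf("b.k inside b.buf? %s\n", inside ? "yes" : "no"); // prints "no"

        b.k = b.buf.data();        // the reinit step: rebind into the copy's own buffer
        return 0;
    }

kv_cache_reinit() performs the same rebinding with ggml: it re-creates the ggml context on top of the copied buffer and allocates fresh k/v tensors of the same type and element count.
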
@@ -563,6 +607,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.type = e_model::MODEL_LARGE;
         }
 
+        // for the big tensors, we have the option to store the data in 16-bit floats
+        // in order to save memory and also to speed up the computation
+        wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+        const size_t scale = model.hparams.f16 ? 1 : 2;
+
         fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
         fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
@@ -576,17 +626,47 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
         fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
-        // for the big tensors, we have the option to store the data in 16-bit floats
-        // in order to save memory and also to speed up the computation
-        wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+        // print memory requirements
+        {
+            // this is the total memory required to run the inference
+            const size_t mem_required =
+                scale*MEM_REQ_MODEL.at       (model.type) +
+                scale*MEM_REQ_KV_CROSS.at    (model.type) +
+                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)) +
+                scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
 
-        const size_t scale = model.hparams.f16 ? 1 : 2;
+            // this is the memory required by one decoder
+            const size_t mem_required_decoder =
+                scale*MEM_REQ_KV_SELF.at(model.type);
+
+            fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+                    mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+        }
 
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        wctx.kv_self.buf.resize (scale*MEM_REQ_KV_SELF.at(model.type));
-        wctx.kv_cross.buf.resize(scale*MEM_REQ_KV_CROSS.at(model.type));
+        wctx.selected_decoder_id = 0;
+
+        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+            return false;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
+            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        }
+
+        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+            return false;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+            fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        }
 
         wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
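
As a sanity check on the new per-decoder report: the self-attention cache holds one K and one V vector of n_text_state elements for every text layer and every context position, so it needs 2 × n_text_state × n_text_layer × n_text_ctx × sizeof(element) bytes. A standalone computation with the base-model text hyperparameters (n_text_state = 512, n_text_layer = 6, n_text_ctx = 448 — assumed here, they are not part of this diff) and F16 storage:

    #include <cstdio>

    int main() {
        const long n_state = 512;                       // n_text_state (base model, assumed)
        const long n_layer = 6;                         // n_text_layer
        const long n_ctx   = 448;                       // n_text_ctx
        const long n_elements = n_state*n_layer*n_ctx;  // per tensor (k or v)
        const long n_bytes    = 2*n_elements*2;         // k + v, 2 bytes per F16 element
        std::printf("kv self size = %7.2f MB per decoder\n", n_bytes/1024.0/1024.0);
        // prints ~5.25 MB; the F32 path doubles this, which is what scale = 2 accounts for
        return 0;
    }
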
@@ -673,26 +753,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-        wctx.work_logits.reserve(vocab.n_vocab);
+        wctx.work_logits.reserve  (vocab.n_vocab);
         wctx.work_logprobs.reserve(vocab.n_vocab);
 
         wctx.probs_id.reserve(n_vocab);
     }
 
-    {
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            wctx.model.buf->size() +
-            wctx.kv_self.buf.size() +
-            wctx.kv_cross.buf.size() +
-            wctx.buf_compute.size() +
-            wctx.buf_compute_layer.size();
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
-    }
-
     size_t ctx_size = 0;
 
     const ggml_type wtype = wctx.wtype;
@@ -1016,27 +1084,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // TODO: move to decoder
-    if (!init_kv_cache(model.hparams, wctx.kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-        fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__);
-        return false;
-    }
-
-    {
-        const size_t memory_size = ggml_nbytes(wctx.kv_self.k) + ggml_nbytes(wctx.kv_self.v);
-        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-    }
-
-    if (!init_kv_cache(model.hparams, wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-        fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__);
-        return false;
-    }
-
-    {
-        const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-    }
-
     // load weights
     {
         size_t total_size = 0;
@@ -1266,7 +1313,7 @@ static bool whisper_encode(
             // ------
 
-#ifdef USE_FLASH_ATTN
+#ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1377,7 +1424,7 @@ static bool whisper_encode(
                             ggml_repeat(ctxL, layer.mlp_ln_b, cur));
             }
 
-#ifdef USE_FLASH_FF
+#ifdef WHISPER_USE_FLASH_FF
             cur = ggml_flash_ff(ctxL,
                     ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
@@ -1539,6 +1586,13 @@ static bool whisper_decode(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
+    WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
+
+    auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
+
+    WHISPER_ASSERT(!!kv_self.ctx);
+
     auto & logits_out = wctx.logits;
     auto & probs_out  = wctx.probs;
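
With the cache resolved through selected_decoder_id instead of the old fixed wctx.kv_self, switching the active decoder amounts to setting the index before calling whisper_decode. The WHISPER_ASSERT(!!kv_self.ctx) catches use of a decoder slot whose cache was never initialized — at load time only decoders[0] is set up.
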
@@ -1630,8 +1684,8 @@ static bool whisper_decode(
             // store key and value to memory
             {
-                struct ggml_tensor * k = ggml_view_1d(ctxL, wctx.kv_self.k, N*n_state, (ggml_element_size(wctx.kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctxL, wctx.kv_self.v, N*n_state, (ggml_element_size(wctx.kv_self.v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1649,7 +1703,7 @@ static bool whisper_decode(
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, wctx.kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.k)*n_state),
+                            ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
@@ -1669,7 +1723,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, wctx.kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.v)*n_state),
+                            ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
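
The write views in the earlier hunk and the read views here rely on the same flat layout: kv_self.k and kv_self.v are 1D tensors of n_text_layer × n_ctx × n_state elements, so layer il occupies the slab starting at element il*n_ctx*n_state, and new tokens are appended n_past positions into that slab. A small standalone check of the element offsets (dimensions assumed; the real calls multiply by ggml_element_size() to get byte offsets):

    #include <cstdio>

    int main() {
        const long n_state = 512, n_ctx = 448;   // assumed text dims (base model)
        const long il = 3, n_past = 10, N = 2;   // example layer / cached / new tokens

        // write: N*n_state elements appended after the n_past cached positions
        const long write_off = n_state*(il*n_ctx + n_past);
        // read: the first n_past + N positions of layer il's slab
        const long read_off  = il*n_ctx*n_state;

        std::printf("layer %ld: write %ld elems at %ld, read %ld elems at %ld\n",
                    il, N*n_state, write_off, (n_past + N)*n_state, read_off);
        return 0;
    }
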
@@ -2370,12 +2424,14 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.buf) {
             delete ctx->model.buf;
         }
-        if (ctx->kv_self.ctx) {
-            ggml_free(ctx->kv_self.ctx);
-        }
         if (ctx->kv_cross.ctx) {
             ggml_free(ctx->kv_cross.ctx);
         }
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            if (ctx->decoders[i].kv_self.ctx) {
+                ggml_free(ctx->decoders[i].kv_self.ctx);
+            }
+        }
         delete ctx;
     }
 }
@ -3413,16 +3469,18 @@ int whisper_full_parallel(
ctx_p = *ctx; ctx_p = *ctx;
if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_self, ctx_p.wtype, ctx_p.model.hparams.n_text_ctx)) { if (!kv_cache_reinit(ctx_p.kv_cross)) {
fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__);
return false; return false;
} }
if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_cross, ctx_p.wtype, ctx_p.model.hparams.n_audio_ctx)) { for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__); if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__);
return false; return false;
} }
} }
}
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
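
Note the guard on ctx_p.decoders[j].kv_self.ctx: only caches that actually exist in the parent context are rebound. Reinitializing every slot unconditionally would trip the WHISPER_ASSERT(cache.ctx) inside kv_cache_reinit(), since only the decoder slots initialized so far (decoders[0] at load time) have a live ggml context.
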
@@ -3485,6 +3543,12 @@ int whisper_full_parallel(
         ctx->t_sample_us += ctxs[i].t_sample_us;
         ctx->t_encode_us += ctxs[i].t_encode_us;
         ctx->t_decode_us += ctxs[i].t_decode_us;
+
+        kv_cache_free(ctx->kv_cross);
+
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            kv_cache_free(ctx->decoders[j].kv_self);
+        }
     }
 
     // average the timings
