whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders
pull/291/head
Georgi Gerganov 3 years ago
parent 21559537e2
commit 523e0494a6
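The change splits the single key/value arena into two: a self-attention KV cache, whose size scales with the text context n_text_ctx, and a cross-attention KV cache, whose size scales with the encoder output length n_audio_ctx. The cross-attention keys/values are computed once per encoder pass and stay read-only during decoding, while each future parallel decoder will need its own writable self-attention cache. A back-of-the-envelope check of the new per-model budgets, as a standalone sketch (the hyperparameter values below are Whisper tiny's published ones, stated here as assumptions; in whisper.cpp they live in whisper_hparams):

#include <cstdio>

int main() {
    // assumed Whisper tiny hyperparameters (not taken from this diff)
    const int n_text_state = 384;   // model width
    const int n_text_layer = 4;     // decoder layers
    const int n_text_ctx   = 448;   // max text (decoder) context
    const int n_audio_ctx  = 1500;  // encoder output length (30 s of audio)

    const size_t elem = 2;          // F16, 2 bytes per element

    // one K and one V tensor per decoder layer
    const size_t self_kv  = 2ull*n_text_layer*n_text_ctx *n_text_state*elem;
    const size_t cross_kv = 2ull*n_text_layer*n_audio_ctx*n_text_state*elem;

    printf("self  kv = %.2f MB\n", self_kv /1024.0/1024.0);  // ~2.63 MB -> 3 MB budget
    printf("cross kv = %.2f MB\n", cross_kv/1024.0/1024.0);  // ~8.79 MB -> 9 MB budget

    return 0;
}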

@@ -150,11 +150,19 @@ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
 };
 
 static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
-    { MODEL_TINY,    12ull*MB },
-    { MODEL_BASE,    24ull*MB },
-    { MODEL_SMALL,   70ull*MB },
-    { MODEL_MEDIUM, 184ull*MB },
-    { MODEL_LARGE,  306ull*MB },
+    { MODEL_TINY,    3ull*MB },
+    { MODEL_BASE,    6ull*MB },
+    { MODEL_SMALL,  16ull*MB },
+    { MODEL_MEDIUM, 43ull*MB },
+    { MODEL_LARGE,  71ull*MB },
 };
 
+static const std::map<e_model, size_t> MEM_REQ_MEMORY_CROSS = {
+    { MODEL_TINY,     9ull*MB },
+    { MODEL_BASE,    18ull*MB },
+    { MODEL_SMALL,   53ull*MB },
+    { MODEL_MEDIUM, 141ull*MB },
+    { MODEL_LARGE,  235ull*MB },
+};
+
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
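Per model, the two new budgets add back up to the old combined MEM_REQ_MEMORY values (tiny: 3 + 9 = 12 MB, base: 6 + 18 = 24 MB, medium: 43 + 141 = 184 MB, large: 71 + 235 = 306 MB; small rounds to the old 70 from 16 + 53 = 69 MB), so the total stays the same: the memory is only partitioned differently.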
@@ -391,22 +399,27 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // key + value memory
+    // key + value memory for self attention
     struct ggml_tensor * memory_k;
     struct ggml_tensor * memory_v;
 
+    // key + value memory for cross attention
     struct ggml_tensor * memory_cross_k;
     struct ggml_tensor * memory_cross_v;
 
     // context
     struct ggml_context * ctx;
     struct ggml_context * ctx_mem;
+    struct ggml_context * ctx_mem_cross;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct whisper_decoder_data {
+};
+
 struct whisper_context {
     int64_t t_load_us = 0;
     int64_t t_mel_us = 0;
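The empty whisper_decoder_data struct is new scaffolding: judging by the commit message, it is presumably where per-decoder state (most likely a decoder's own self-attention KV tensors) will land in the follow-up parallel-decoders work.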
@@ -417,6 +430,7 @@ struct whisper_context {
 
     std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
     std::vector<uint8_t>   buf_memory;
+    std::vector<uint8_t>   buf_memory_cross;
     std::vector<uint8_t>   buf_compute;
     std::vector<uint8_t>   buf_compute_layer;
@@ -533,6 +547,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         wctx.buf_model = new std::vector<uint8_t>();
         wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
         wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
+        wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
@@ -631,6 +646,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     const size_t mem_required =
              wctx.buf_model->size() +
             wctx.buf_memory.size() +
+            wctx.buf_memory_cross.size() +
             wctx.buf_compute.size() +
             wctx.buf_compute_layer.size();
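The three small hunks above are the plumbing for the new buffer: buf_memory_cross is declared next to buf_memory in whisper_context, sized from MEM_REQ_MEMORY_CROSS at load time, and counted in the mem_required estimate.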
@@ -964,31 +980,27 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }
 
-    // create the ggml memory context
+    // create the ggml context for the key/value memory (self-attention)
     {
         struct ggml_init_params params;
         params.mem_size   = wctx.buf_memory.size();
         params.mem_buffer = wctx.buf_memory.data();
 
-        model.ctx_mem = ggml_init(params);
-        if (!model.ctx_mem) {
+        auto & ctx = model.ctx_mem;
+
+        ctx = ggml_init(params);
+        if (!ctx) {
             fprintf(stderr, "%s: ggml_init() failed\n", __func__);
             return false;
         }
-    }
 
-    // key + value memory
-    {
-        auto & ctx = model.ctx_mem;
-
-        const auto & hparams = model.hparams;
+        const auto & hparams = model.hparams;
 
-        {
-            const int n_text_state = hparams.n_text_state;
-            const int n_text_layer = hparams.n_text_layer;
-            const int n_text_ctx   = hparams.n_text_ctx;
+        const int n_text_state = hparams.n_text_state;
+        const int n_text_layer = hparams.n_text_layer;
+        const int n_text_ctx   = hparams.n_text_ctx;
 
-            // key/value memory for the self-attention layer
-            {
+        {
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
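Note the restructuring: previously, context creation and tensor allocation lived in two separate blocks; now each attention type gets a single block that creates its ggml context and immediately allocates its K/V tensors from it, so the self-attention path no longer touches anything cross-attention related.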
@@ -996,9 +1008,30 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
-        // key/value memory for the cross-attention layer
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    }
+
+    // create the ggml context for the key/value memory (cross-attention)
+    {
+        struct ggml_init_params params;
+        params.mem_size   = wctx.buf_memory_cross.size();
+        params.mem_buffer = wctx.buf_memory_cross.data();
+
+        auto & ctx = model.ctx_mem_cross;
+
+        ctx = ggml_init(params);
+        if (!ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+
+        const auto & hparams = model.hparams;
+
+        const int n_text_state = hparams.n_text_state;
+        const int n_text_layer = hparams.n_text_layer;
+        const int n_audio_ctx  = hparams.n_audio_ctx;
+
         {
-            const int n_audio_ctx = hparams.n_audio_ctx;
-
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
@@ -1007,10 +1040,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+        const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
     }
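The reason a second ggml context is needed, rather than just a second pair of tensors: a ggml context is a fixed-size arena that hands out tensors from a caller-supplied buffer and is created and freed as a unit. Giving each attention type its own arena means one can later be replaced or replicated per decoder without disturbing the other. A condensed sketch of the pattern, assuming the two-field ggml_init_params used by the ggml version in this diff and a hypothetical helper name alloc_kv:

#include "ggml.h"

#include <cstdint>
#include <vector>

// hypothetical helper (not in whisper.cpp): carve one K/V pair
// out of its own fixed-size arena
static bool alloc_kv(std::vector<uint8_t> & buf, struct ggml_context ** mem_ctx,
                     struct ggml_tensor ** k, struct ggml_tensor ** v,
                     int n_layer, int n_ctx, int n_state, enum ggml_type wtype) {
    struct ggml_init_params params;
    params.mem_size   = buf.size();  // the arena never grows past this
    params.mem_buffer = buf.data();  // tensors are placed inside this buffer

    *mem_ctx = ggml_init(params);
    if (!*mem_ctx) {
        return false;
    }

    const int n_elements = n_layer*n_ctx*n_state;

    *k = ggml_new_tensor_1d(*mem_ctx, wtype, n_elements);
    *v = ggml_new_tensor_1d(*mem_ctx, wtype, n_elements);

    return true;
}

With that shape, the load path above amounts to two calls: one with buf_memory and n_text_ctx for self-attention, one with buf_memory_cross and n_audio_ctx for cross-attention.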
@@ -2345,6 +2376,9 @@ void whisper_free(struct whisper_context * ctx) {
     if (ctx->model.ctx_mem) {
         ggml_free(ctx->model.ctx_mem);
     }
+    if (ctx->model.ctx_mem_cross) {
+        ggml_free(ctx->model.ctx_mem_cross);
+    }
     if (ctx->buf_model) {
         delete ctx->buf_model;
     }
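whisper_free releases the new context alongside the old one; since the arenas are backed by caller-supplied buffers, ggml_free only tears down the ggml bookkeeping, while the buf_memory and buf_memory_cross bytes remain owned by the whisper_context vectors.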
@@ -3381,48 +3415,57 @@ int whisper_full_parallel(
         auto & model = ctxs[i].model;
 
-        // create the ggml memory context
+        // separate key + value memory for each processor (self-attention)
         {
             struct ggml_init_params params;
             params.mem_size   = ctxs[i].buf_memory.size();
             params.mem_buffer = ctxs[i].buf_memory.data();
 
-            model.ctx_mem = ggml_init(params);
-            if (!model.ctx_mem) {
+            auto & mctx = model.ctx_mem;
+
+            mctx = ggml_init(params);
+            if (!mctx) {
                 fprintf(stderr, "%s: ggml_init() failed\n", __func__);
                 return false;
             }
-        }
 
-        // separate key + value memory for each processor
-        {
-            auto & mctx = model.ctx_mem;
-
-            const auto & hparams = model.hparams;
+            const auto & hparams = model.hparams;
 
-            {
-                const int n_text_state = hparams.n_text_state;
-                const int n_text_layer = hparams.n_text_layer;
-                const int n_text_ctx   = hparams.n_text_ctx;
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
+            const int n_text_ctx   = hparams.n_text_ctx;
 
             // key/value memory for the self-attention layer
             {
                 const int n_mem      = n_text_layer*n_text_ctx;
                 const int n_elements = n_text_state*n_mem;
 
                 model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
                 model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
+        }
 
-            // key/value memory for the cross-attention layer
-            {
-                const int n_audio_ctx = hparams.n_audio_ctx;
-
-                const int n_mem      = n_text_layer*n_audio_ctx;
-                const int n_elements = n_text_state*n_mem;
-
-                model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-                model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            }
-        }
+        // separate key + value memory for each processor (cross-attention)
+        {
+            struct ggml_init_params params;
+            params.mem_size   = ctxs[i].buf_memory_cross.size();
+            params.mem_buffer = ctxs[i].buf_memory_cross.data();
+
+            auto & mctx = model.ctx_mem_cross;
+
+            mctx = ggml_init(params);
+            if (!mctx) {
+                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+                return false;
+            }
+
+            const auto & hparams = model.hparams;
+
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
+            const int n_audio_ctx  = hparams.n_audio_ctx;
+
+            // key/value memory for the cross-attention layer
+            {
+                const int n_mem      = n_text_layer*n_audio_ctx;
+                const int n_elements = n_text_state*n_mem;
+
+                model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+                model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+            }
+        }
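Each worker context created inside whisper_full_parallel already received a private copy of the model's KV memory; the change keeps that behavior for both arenas, so processors never share mutable attention state. A hypothetical caller sketch (whisper_full_parallel and whisper_full_default_params are the existing public API; the surrounding function is made up):

#include "whisper.h"

// transcribe pcm (16 kHz mono float samples) with 4 parallel processors,
// each of which now gets its own self- and cross-attention KV arenas
int transcribe(struct whisper_context * wctx, const float * pcm, int n_samples) {
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    return whisper_full_parallel(wctx, wparams, pcm, n_samples, 4);
}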
