whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders
pull/291/head
Georgi Gerganov 3 years ago
parent 21559537e2
commit 523e0494a6

whisper.cpp

@@ -150,11 +150,19 @@ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
 };
 
 static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
-    { MODEL_TINY,    12ull*MB },
-    { MODEL_BASE,    24ull*MB },
-    { MODEL_SMALL,   70ull*MB },
-    { MODEL_MEDIUM, 184ull*MB },
-    { MODEL_LARGE,  306ull*MB },
+    { MODEL_TINY,     3ull*MB },
+    { MODEL_BASE,     6ull*MB },
+    { MODEL_SMALL,   16ull*MB },
+    { MODEL_MEDIUM,  43ull*MB },
+    { MODEL_LARGE,   71ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_MEMORY_CROSS = {
+    { MODEL_TINY,     9ull*MB },
+    { MODEL_BASE,    18ull*MB },
+    { MODEL_SMALL,   53ull*MB },
+    { MODEL_MEDIUM, 141ull*MB },
+    { MODEL_LARGE,  235ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
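
The split budgets line up with the size of the k/v tensors implied by the model hyperparameters. A back-of-the-envelope check for MODEL_TINY, assuming the standard Whisper "tiny" hyperparameters (n_text_state = 384, n_text_layer = 4, n_text_ctx = 448, n_audio_ctx = 1500 — these values are not part of this diff) and F16 storage:

// Sketch: estimate the self/cross k/v budgets for MODEL_TINY.
// The hyperparameter values are assumptions, not taken from this diff.
#include <cstdio>

int main() {
    const int n_text_state = 384;   // assumed
    const int n_text_layer = 4;     // assumed
    const int n_text_ctx   = 448;   // decoder context (self-attention), assumed
    const int n_audio_ctx  = 1500;  // encoder context (cross-attention), assumed

    // 2 tensors (k and v), F16 = 2 bytes per element
    const size_t self_bytes  = 2ull*n_text_layer*n_text_ctx *n_text_state*2;
    const size_t cross_bytes = 2ull*n_text_layer*n_audio_ctx*n_text_state*2;

    printf("self  k/v: %6.2f MB\n", self_bytes /1024.0/1024.0); // ~2.62 -> 3ull*MB
    printf("cross k/v: %6.2f MB\n", cross_bytes/1024.0/1024.0); // ~8.79 -> 9ull*MB

    return 0;
}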
@@ -391,22 +399,27 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // key + value memory
+    // key + value memory for self attention
     struct ggml_tensor * memory_k;
     struct ggml_tensor * memory_v;
 
+    // key + value memory for cross attention
     struct ggml_tensor * memory_cross_k;
     struct ggml_tensor * memory_cross_v;
 
     // context
     struct ggml_context * ctx;
     struct ggml_context * ctx_mem;
+    struct ggml_context * ctx_mem_cross;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct whisper_decoder_data {
+};
+
 struct whisper_context {
     int64_t t_load_us = 0;
     int64_t t_mel_us = 0;
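
The separation also hints at where the parallel-decoder work is headed: the cross-attention k/v depend only on the encoder output, while the self-attention k/v grow per decoded token. One plausible end state (an assumption on my part — `decoder_kv` and `shared_cross_kv` are hypothetical names, and `whisper_decoder_data` above is still an empty placeholder):

// Hypothetical sketch (assumption, not in this commit): with the two k/v
// stores separated, parallel decoders could each own a self-attention store
// while sharing one cross-attention store computed once per encoder pass.
struct ggml_tensor; // forward declaration, as in ggml.h

struct decoder_kv {
    // per decoder - filled incrementally as tokens are generated
    struct ggml_tensor * memory_k;
    struct ggml_tensor * memory_v;
};

struct shared_cross_kv {
    // written once per encoder pass, then read-only for every decoder
    struct ggml_tensor * memory_cross_k;
    struct ggml_tensor * memory_cross_v;
};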
@@ -417,6 +430,7 @@ struct whisper_context {
     std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
     std::vector<uint8_t> buf_memory;
+    std::vector<uint8_t> buf_memory_cross;
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
@@ -533,6 +547,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         wctx.buf_model = new std::vector<uint8_t>();
         wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
         wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
+        wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
@@ -631,6 +646,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         const size_t mem_required =
                  wctx.buf_model->size() +
                  wctx.buf_memory.size() +
+                 wctx.buf_memory_cross.size() +
                  wctx.buf_compute.size() +
                  wctx.buf_compute_layer.size();
@@ -964,31 +980,27 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }
 
-    // create the ggml memory context
+    // create the ggml context for the key/value memory (self-attention)
     {
         struct ggml_init_params params;
         params.mem_size   = wctx.buf_memory.size();
         params.mem_buffer = wctx.buf_memory.data();
 
-        model.ctx_mem = ggml_init(params);
-        if (!model.ctx_mem) {
+        auto & ctx = model.ctx_mem;
+
+        ctx = ggml_init(params);
+        if (!ctx) {
             fprintf(stderr, "%s: ggml_init() failed\n", __func__);
             return false;
         }
-    }
-
-    // key + value memory
-    {
-        auto & ctx = model.ctx_mem;
 
         const auto & hparams = model.hparams;
 
         const int n_text_state = hparams.n_text_state;
         const int n_text_layer = hparams.n_text_layer;
         const int n_text_ctx   = hparams.n_text_ctx;
 
-        // key/value memory for the self-attention layer
        {
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
@@ -996,8 +1008,29 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
-        // key/value memory for the cross-attention layer
-        {
-            const int n_audio_ctx = hparams.n_audio_ctx;
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+    }
+
+    // create the ggml context for the key/value memory (cross-attention)
+    {
+        struct ggml_init_params params;
+        params.mem_size   = wctx.buf_memory_cross.size();
+        params.mem_buffer = wctx.buf_memory_cross.data();
+
+        auto & ctx = model.ctx_mem_cross;
+
+        ctx = ggml_init(params);
+        if (!ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+
+        {
+            const auto & hparams = model.hparams;
+
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
+            const int n_audio_ctx  = hparams.n_audio_ctx;
 
             const int n_mem      = n_text_layer*n_audio_ctx;
@@ -1007,10 +1040,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
         const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
         fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
     }
 
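
Both blocks follow the same arena pattern: ggml_init is handed a caller-owned buffer via ggml_init_params, and every tensor created in that context is carved out of that buffer with no further allocation. A minimal self-contained sketch of the pattern, matching the ggml API of this era (the buffer and tensor sizes here are made up for illustration):

// Sketch: two ggml arenas over caller-owned buffers, mirroring the
// ctx_mem / ctx_mem_cross split above. Sizes are illustrative only.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "ggml.h"

bool make_kv_stores() {
    std::vector<uint8_t> buf_self (16u*1024*1024);
    std::vector<uint8_t> buf_cross(16u*1024*1024);

    struct ggml_init_params params_self;
    params_self.mem_size   = buf_self.size();
    params_self.mem_buffer = buf_self.data();

    struct ggml_context * ctx_self = ggml_init(params_self);
    if (!ctx_self) return false;

    struct ggml_init_params params_cross;
    params_cross.mem_size   = buf_cross.size();
    params_cross.mem_buffer = buf_cross.data();

    struct ggml_context * ctx_cross = ggml_init(params_cross);
    if (!ctx_cross) return false;

    // each tensor lives inside its own arena - freeing one context cannot
    // invalidate tensors that belong to the other
    struct ggml_tensor * k_self  = ggml_new_tensor_1d(ctx_self,  GGML_TYPE_F16, 1024);
    struct ggml_tensor * k_cross = ggml_new_tensor_1d(ctx_cross, GGML_TYPE_F16, 1024);

    fprintf(stderr, "self: %zu bytes, cross: %zu bytes\n",
            ggml_nbytes(k_self), ggml_nbytes(k_cross));

    ggml_free(ctx_cross);
    ggml_free(ctx_self);

    return true;
}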
@@ -2345,6 +2376,9 @@ void whisper_free(struct whisper_context * ctx) {
     if (ctx->model.ctx_mem) {
         ggml_free(ctx->model.ctx_mem);
     }
+    if (ctx->model.ctx_mem_cross) {
+        ggml_free(ctx->model.ctx_mem_cross);
+    }
     if (ctx->buf_model) {
         delete ctx->buf_model;
     }
@@ -3381,40 +3415,50 @@ int whisper_full_parallel(
         auto & model = ctxs[i].model;
 
-        // create the ggml memory context
+        // separate key + value memory for each processor (self-attention)
         {
             struct ggml_init_params params;
             params.mem_size   = ctxs[i].buf_memory.size();
             params.mem_buffer = ctxs[i].buf_memory.data();
 
-            model.ctx_mem = ggml_init(params);
-            if (!model.ctx_mem) {
+            auto & mctx = model.ctx_mem;
+            mctx = ggml_init(params);
+            if (!mctx) {
                 fprintf(stderr, "%s: ggml_init() failed\n", __func__);
                 return false;
             }
-        }
-
-        // separate key + value memory for each processor
-        {
-            auto & mctx = model.ctx_mem;
 
             const auto & hparams = model.hparams;
 
             const int n_text_state = hparams.n_text_state;
             const int n_text_layer = hparams.n_text_layer;
             const int n_text_ctx   = hparams.n_text_ctx;
 
-            // key/value memory for the self-attention layer
             {
                 const int n_mem      = n_text_layer*n_text_ctx;
                 const int n_elements = n_text_state*n_mem;
 
                 model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
                 model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
+        }
 
-        // key/value memory for the cross-attention layer
+        // separate key + value memory for each processor (cross-attention)
         {
+            struct ggml_init_params params;
+            params.mem_size   = ctxs[i].buf_memory_cross.size();
+            params.mem_buffer = ctxs[i].buf_memory_cross.data();
+
+            auto & mctx = model.ctx_mem_cross;
+            mctx = ggml_init(params);
+            if (!mctx) {
+                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+                return false;
+            }
+
+            const auto & hparams = model.hparams;
+
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
             const int n_audio_ctx  = hparams.n_audio_ctx;
 
             const int n_mem      = n_text_layer*n_audio_ctx;
@@ -3424,7 +3468,6 @@ int whisper_full_parallel(
             model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
         }
     }
-    }
 
     const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
     const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
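
The trailing context lines show how the parallel work is divided: after skipping the requested offset, each processor gets an equal slice of the remaining samples, and after this commit each slice comes with its own self/cross k/v memory. A quick worked example of that arithmetic, assuming WHISPER_SAMPLE_RATE is 16000 (its usual value, not shown in this diff):

// Sketch of the per-processor split, assuming WHISPER_SAMPLE_RATE == 16000.
#include <cstdio>

int main() {
    const int WHISPER_SAMPLE_RATE = 16000;            // assumed
    const int n_samples    = 60*WHISPER_SAMPLE_RATE;  // 60 s of audio
    const int offset_ms    = 0;
    const int n_processors = 4;

    const int offset_samples          = (WHISPER_SAMPLE_RATE*offset_ms)/1000;
    const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;

    // 960000 samples total -> 240000 samples (15 s) per processor
    printf("%d samples per processor\n", n_samples_per_processor);
    return 0;
}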
