whisper : fix extra memory usage after recent processor changes

Had increased the memory buffer to the size of the model and forgot to
bring it down.
pull/95/head^2
Georgi Gerganov 2 years ago
parent c63ce24834
commit 02dfd5b8c3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -133,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 86ull*MB },
{ MODEL_BASE, 165ull*MB },
{ MODEL_SMALL, 540ull*MB },
{ MODEL_MEDIUM, 1650ull*MB },
{ MODEL_LARGE, 3260ull*MB },
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@ -498,7 +506,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
@ -722,20 +730,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
auto & ctx = model.ctx;
@ -932,6 +926,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory
{
auto & ctx = model.ctx_mem;

Loading…
Cancel
Save