diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 90b7f7c..312635d 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -252,26 +252,26 @@ static const std::map MEM_REQ_SCRATCH3 = { { MODEL_LARGE, 9ull*MB }, }; -// TODO determine this on runtime -#if 0 // F16 -static const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 74ull*MB }, - { MODEL_BASE, 142ull*MB }, - { MODEL_SMALL, 466ull*MB }, - { MODEL_MEDIUM, 1464ull*MB }, - { MODEL_LARGE, 2952ull*MB }, -}; - -#else // Q4 - -static const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 1ull*MB }, // TODO - { MODEL_BASE, 1ull*MB }, // TODO - { MODEL_SMALL, 154ull*MB }, - { MODEL_MEDIUM, 470ull*MB }, - { MODEL_LARGE, 940ull*MB }, +static const std::map> MEM_REQ_MODEL = { + { GGML_TYPE_F16, + { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, + }, + }, + { GGML_TYPE_Q4_0, + { + { MODEL_TINY, 26ull*MB }, + { MODEL_BASE, 50ull*MB }, + { MODEL_SMALL, 154ull*MB }, + { MODEL_MEDIUM, 470ull*MB }, + { MODEL_LARGE, 940ull*MB }, + }, + }, }; -#endif static const std::map MEM_REQ_KV_SELF = { { MODEL_TINY, 3ull*MB }, @@ -829,7 +829,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con MEM_REQ_SCRATCH1.at (model.type) + MEM_REQ_SCRATCH2.at (model.type) + MEM_REQ_SCRATCH3.at (model.type) + - scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); @@ -845,7 +845,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // always have at least one decoder wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);