From 6859c6207ab99e84f8bf2e5e1f99e080ae6c09d1 Mon Sep 17 00:00:00 2001
From: Syahmi Azhar
Date: Sun, 1 Jan 2023 20:17:11 +0800
Subject: [PATCH] whisper : add loader to allow loading from other than file

---
 whisper.cpp | 105 ++++++++++++++++++++++++++++++++--------------------
 whisper.h   |   9 +++++
 2 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index e8d9f0c..62c40de 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -437,8 +437,8 @@ struct whisper_context {
 };
 
 template<typename T>
-static void read_safe(std::ifstream& fin, T& dest) {
-    fin.read((char*)& dest, sizeof(T));
+static void read_safe(whisper_model_loader * loader, T & dest) {
+    loader->read(loader->context, &dest, sizeof(T));
 }
 
 // load the model from a ggml file
@@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
-    fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
+static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
+    fprintf(stderr, "%s: loading model\n", __func__);
 
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
 
-    auto fin = std::ifstream(fname, std::ios::binary);
-    if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
-        return false;
-    }
-
     // verify magic
     {
         uint32_t magic;
-        read_safe(fin, magic);
+        read_safe(loader, magic);
         if (magic != 0x67676d6c) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
             return false;
         }
     }
@@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        read_safe(fin, hparams.n_vocab);
-        read_safe(fin, hparams.n_audio_ctx);
-        read_safe(fin, hparams.n_audio_state);
-        read_safe(fin, hparams.n_audio_head);
-        read_safe(fin, hparams.n_audio_layer);
-        read_safe(fin, hparams.n_text_ctx);
-        read_safe(fin, hparams.n_text_state);
-        read_safe(fin, hparams.n_text_head);
-        read_safe(fin, hparams.n_text_layer);
-        read_safe(fin, hparams.n_mels);
-        read_safe(fin, hparams.f16);
+        read_safe(loader, hparams.n_vocab);
+        read_safe(loader, hparams.n_audio_ctx);
+        read_safe(loader, hparams.n_audio_state);
+        read_safe(loader, hparams.n_audio_head);
+        read_safe(loader, hparams.n_audio_layer);
+        read_safe(loader, hparams.n_text_ctx);
+        read_safe(loader, hparams.n_text_state);
+        read_safe(loader, hparams.n_text_head);
+        read_safe(loader, hparams.n_text_layer);
+        read_safe(loader, hparams.n_mels);
+        read_safe(loader, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
@@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        read_safe(fin, filters.n_mel);
-        read_safe(fin, filters.n_fft);
+        read_safe(loader, filters.n_mel);
+        read_safe(loader, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
-        fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
+        loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
     }
 
     // load vocab
     {
         int32_t n_vocab = 0;
-        read_safe(fin, n_vocab);
+        read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            read_safe(fin, len);
+            read_safe(loader, len);
 
             if (len > 0) {
                 tmp.resize(len);
-                fin.read(&tmp[0], tmp.size()); // read to buffer
+                loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
@@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            read_safe(fin, n_dims);
-            read_safe(fin, length);
-            read_safe(fin, ftype);
+            read_safe(loader, n_dims);
+            read_safe(loader, length);
+            read_safe(loader, ftype);
 
-            if (fin.eof()) {
+            if (loader->eof(loader->context)) {
                 break;
             }
 
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                read_safe(fin, ne[i]);
+                read_safe(loader, ne[i]);
                 nelements *= ne[i];
             }
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read(&tmp[0], tmp.size()); // read to buffer
+            loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 return false;
             }
 
-            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
@@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         }
     }
 
-    fin.close();
-
     return true;
 }
 
@@ -2241,6 +2233,37 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 //
 
 struct whisper_context * whisper_init(const char * path_model) {
+    whisper_model_loader loader;
+
+    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
+
+    auto fin = std::ifstream(path_model, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    loader.context = &fin;
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    return whisper_init_loader(&loader);
+}
+
+struct whisper_context * whisper_init_loader(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2249,13 +2272,15 @@ struct whisper_context * whisper_init(const char * path_model) {
 
     ctx->t_start_us = t_start_us;
 
-    if (!whisper_model_load(path_model, *ctx)) {
-        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
+    if (!whisper_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
 
     ctx->t_load_us = ggml_time_us() - t_start_us;
+    loader->close(loader->context);
 
     return ctx;
 }
diff --git a/whisper.h b/whisper.h
index 8cb16ca..0fa936b 100644
--- a/whisper.h
+++ b/whisper.h
@@ -84,9 +84,18 @@ extern "C" {
         float vlen;   // voice length of the token
     } whisper_token_data;
 
+    typedef struct whisper_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool   (*eof)(void * ctx);
+        void   (*close)(void * ctx);
+    } whisper_model_loader;
+
     // Allocates all memory needed for the model and loads the model from the given file.
     // Returns NULL on failure.
     WHISPER_API struct whisper_context * whisper_init(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_loader(struct whisper_model_loader * loader);
 
     // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);
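
Usage sketch (illustrative, not part of the patch): with the read/eof/close callbacks above, a model can be loaded from any source, not only a file path. The example below feeds whisper_init_loader() from a buffer that already holds the ggml model bytes. buffer_context and whisper_init_from_buffer_example() are made-up names for this sketch; only whisper_model_loader and whisper_init_loader come from the patch.

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "whisper.h"

struct buffer_context {
    const uint8_t * data;   // model bytes, owned by the caller
    size_t          size;   // total number of bytes
    size_t          offset; // current read position
};

static struct whisper_context * whisper_init_from_buffer_example(const void * data, size_t size) {
    buffer_context buf = { (const uint8_t *) data, size, 0 };

    whisper_model_loader loader;
    loader.context = &buf;

    // copy up to read_size bytes out of the buffer and report how many were copied
    loader.read = [](void * ctx, void * output, size_t read_size) {
        buffer_context * buf = (buffer_context *) ctx;
        size_t n = std::min(read_size, buf->size - buf->offset);
        memcpy(output, buf->data + buf->offset, n);
        buf->offset += n;
        return n;
    };

    // end-of-stream is simply the end of the buffer
    loader.eof = [](void * ctx) {
        buffer_context * buf = (buffer_context *) ctx;
        return buf->offset >= buf->size;
    };

    // nothing to release; the caller still owns the buffer
    loader.close = [](void *) { };

    // whisper_init_loader() reads the whole model before returning,
    // so loader and buf can safely live on the stack here
    return whisper_init_loader(&loader);
}

The same pattern would let a caller wrap, for example, a compressed stream or a platform asset API, which is the point of decoupling model loading from std::ifstream.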