whisper : add loader to allow loading from other than file

pull/353/head
Syahmi Azhar 3 years ago committed by Georgi Gerganov
parent 52a3e0c92a
commit 6859c6207a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -437,8 +437,8 @@ struct whisper_context {
};
template<typename T>
static void read_safe(std::ifstream& fin, T& dest) {
fin.read((char*)& dest, sizeof(T));
static void read_safe(whisper_model_loader * loader, T & dest) {
loader->read(loader->context, &dest, sizeof(T));
}
// load the model from a ggml file
@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
//
// see the convert-pt-to-ggml.py script for details
//
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
fprintf(stderr, "%s: loading model\n", __func__);
auto & model = wctx.model;
auto & vocab = wctx.vocab;
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
read_safe(fin, magic);
read_safe(loader, magic);
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
return false;
}
}
@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & hparams = model.hparams;
read_safe(fin, hparams.n_vocab);
read_safe(fin, hparams.n_audio_ctx);
read_safe(fin, hparams.n_audio_state);
read_safe(fin, hparams.n_audio_head);
read_safe(fin, hparams.n_audio_layer);
read_safe(fin, hparams.n_text_ctx);
read_safe(fin, hparams.n_text_state);
read_safe(fin, hparams.n_text_head);
read_safe(fin, hparams.n_text_layer);
read_safe(fin, hparams.n_mels);
read_safe(fin, hparams.f16);
read_safe(loader, hparams.n_vocab);
read_safe(loader, hparams.n_audio_ctx);
read_safe(loader, hparams.n_audio_state);
read_safe(loader, hparams.n_audio_head);
read_safe(loader, hparams.n_audio_layer);
read_safe(loader, hparams.n_text_ctx);
read_safe(loader, hparams.n_text_state);
read_safe(loader, hparams.n_text_head);
read_safe(loader, hparams.n_text_layer);
read_safe(loader, hparams.n_mels);
read_safe(loader, hparams.f16);
assert(hparams.n_text_state == hparams.n_audio_state);
@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & filters = wctx.model.filters;
read_safe(fin, filters.n_mel);
read_safe(fin, filters.n_fft);
read_safe(loader, filters.n_mel);
read_safe(loader, filters.n_fft);
filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
}
// load vocab
{
int32_t n_vocab = 0;
read_safe(fin, n_vocab);
read_safe(loader, n_vocab);
//if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
read_safe(fin, len);
read_safe(loader, len);
if (len > 0) {
tmp.resize(len);
fin.read(&tmp[0], tmp.size()); // read to buffer
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
word.assign(&tmp[0], tmp.size());
} else {
// seems like we have an empty-string token in multi-language models (i = 50256)
@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length;
int32_t ftype;
read_safe(fin, n_dims);
read_safe(fin, length);
read_safe(fin, ftype);
read_safe(loader, n_dims);
read_safe(loader, length);
read_safe(loader, ftype);
if (fin.eof()) {
if (loader->eof(loader->context)) {
break;
}
int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
read_safe(fin, ne[i]);
read_safe(loader, ne[i]);
nelements *= ne[i];
}
std::string name;
std::vector<char> tmp(length); // create a buffer
fin.read(&tmp[0], tmp.size()); // read to buffer
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) {
@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
}
}
fin.close();
return true;
}
@ -2241,6 +2233,37 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
//
struct whisper_context * whisper_init(const char * path_model) {
whisper_model_loader loader;
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
auto fin = std::ifstream(path_model, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
return nullptr;
}
loader.context = &fin;
loader.read = [](void * ctx, void * output, size_t read_size) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->read((char *)output, read_size);
return read_size;
};
loader.eof = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
return fin->eof();
};
loader.close = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->close();
};
return whisper_init_loader(&loader);
}
struct whisper_context * whisper_init_loader(struct whisper_model_loader * loader) {
ggml_time_init();
whisper_context * ctx = new whisper_context;
@ -2249,13 +2272,15 @@ struct whisper_context * whisper_init(const char * path_model) {
ctx->t_start_us = t_start_us;
if (!whisper_model_load(path_model, *ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
if (!whisper_model_load(loader, *ctx)) {
loader->close(loader->context);
fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx;
return nullptr;
}
ctx->t_load_us = ggml_time_us() - t_start_us;
loader->close(loader->context);
return ctx;
}

@ -84,9 +84,18 @@ extern "C" {
float vlen; // voice length of the token
} whisper_token_data;
typedef struct whisper_model_loader {
void* context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_loader(struct whisper_model_loader * loader);
// Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx);

Loading…
Cancel
Save