Fix a potential bug when reading model data into a small-string-optimized (SSO) string, which could lead to memory corruption. With an SSO string, you can't write data to &str[0] and expect it to work well.

Also added a small wrapper function that reads model data more safely, without the caller having to get the sizeof right. I tested this on the tiny, base and large models; there was no change in behaviour.
pull/253/head
bert hubert 2 years ago committed by Georgi Gerganov
parent 603f97ba11
commit d1da35de06

@ -429,6 +429,12 @@ struct whisper_context {
int32_t exp_n_audio_ctx; // 0 - use default int32_t exp_n_audio_ctx; // 0 - use default
}; };
// Reads exactly sizeof(dest) raw bytes from `fin` straight into `dest`.
// The byte count is taken from the destination itself, so a call site
// cannot pass a size that disagrees with the variable being filled.
// No stream-state check happens here; callers inspect `fin` afterwards
// if they need to detect EOF or a short read.
template <typename T>
static void read_safe(std::ifstream & fin, T & dest) {
    char * const raw = reinterpret_cast<char *>(&dest);
    fin.read(raw, sizeof(dest));
}
// load the model from a ggml file // load the model from a ggml file
// //
// file format: // file format:
@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// verify magic // verify magic
{ {
uint32_t magic; uint32_t magic;
fin.read((char *) &magic, sizeof(magic)); read_safe(fin, magic);
if (magic != 0x67676d6c) { if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false; return false;
@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{ {
auto & hparams = model.hparams; auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); read_safe(fin, hparams.n_vocab);
fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); read_safe(fin, hparams.n_audio_ctx);
fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); read_safe(fin, hparams.n_audio_state);
fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); read_safe(fin, hparams.n_audio_head);
fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); read_safe(fin, hparams.n_audio_layer);
fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx)); read_safe(fin, hparams.n_text_ctx);
fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state)); read_safe(fin, hparams.n_text_state);
fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head)); read_safe(fin, hparams.n_text_head);
fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer)); read_safe(fin, hparams.n_text_layer);
fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); read_safe(fin, hparams.n_mels);
fin.read((char *) &hparams.f16, sizeof(hparams.f16)); read_safe(fin, hparams.f16);
assert(hparams.n_text_state == hparams.n_audio_state); assert(hparams.n_text_state == hparams.n_audio_state);
@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{ {
auto & filters = wctx.model.filters; auto & filters = wctx.model.filters;
fin.read((char *) &filters.n_mel, sizeof(filters.n_mel)); read_safe(fin, filters.n_mel);
fin.read((char *) &filters.n_fft, sizeof(filters.n_fft)); read_safe(fin, filters.n_fft);
filters.data.resize(filters.n_mel * filters.n_fft); filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float)); fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// load vocab // load vocab
{ {
int32_t n_vocab = 0; int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab)); read_safe(fin, n_vocab);
//if (n_vocab != model.hparams.n_vocab) { //if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std::string word; std::string word;
for (int i = 0; i < n_vocab; i++) { for (int i = 0; i < n_vocab; i++) {
uint32_t len; uint32_t len;
fin.read((char *) &len, sizeof(len)); read_safe(fin, len);
word.resize(len); std::vector<char> tmp(len); // create a buffer
fin.read((char *) word.data(), len); fin.read( &tmp[0], tmp.size() ); // read to buffer
word.assign(&tmp[0], tmp.size());
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word; vocab.id_to_token[i] = word;
@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length; int32_t length;
int32_t ftype; int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); read_safe(fin, n_dims);
fin.read(reinterpret_cast<char *>(&length), sizeof(length)); read_safe(fin, length);
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype)); read_safe(fin, ftype);
if (fin.eof()) { if (fin.eof()) {
break; break;
@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t nelements = 1; int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 }; int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) { for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i])); read_safe(fin, ne[i]);
nelements *= ne[i]; nelements *= ne[i];
} }
std::string name(length, 0); std::string name;
fin.read(&name[0], length); std::vector<char> tmp(length); // create a buffer
fin.read( &tmp[0], tmp.size() ); // read to buffer
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name.data()) == model.tensors.end()) { if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());

Loading…
Cancel
Save