ggml : sync latest whisper.cpp

pull/15/head
Georgi Gerganov 2 years ago
parent 6ed4da0b03
commit dee3684fec
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -478,7 +478,7 @@ int main(int argc, char ** argv) {
// whisper init // whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str()); struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) { if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n"); fprintf(stderr, "error: failed to initialize whisper context\n");

@ -437,8 +437,8 @@ struct whisper_context {
}; };
template<typename T> template<typename T>
static void read_safe(std::ifstream& fin, T& dest) { static void read_safe(whisper_model_loader * loader, T & dest) {
fin.read((char*)& dest, sizeof(T)); loader->read(loader->context, &dest, sizeof(T));
} }
// load the model from a ggml file // load the model from a ggml file
@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
// //
// see the convert-pt-to-ggml.py script for details // see the convert-pt-to-ggml.py script for details
// //
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) { static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model\n", __func__);
auto & model = wctx.model; auto & model = wctx.model;
auto & vocab = wctx.vocab; auto & vocab = wctx.vocab;
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic // verify magic
{ {
uint32_t magic; uint32_t magic;
read_safe(fin, magic); read_safe(loader, magic);
if (magic != 0x67676d6c) { if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
return false; return false;
} }
} }
@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{ {
auto & hparams = model.hparams; auto & hparams = model.hparams;
read_safe(fin, hparams.n_vocab); read_safe(loader, hparams.n_vocab);
read_safe(fin, hparams.n_audio_ctx); read_safe(loader, hparams.n_audio_ctx);
read_safe(fin, hparams.n_audio_state); read_safe(loader, hparams.n_audio_state);
read_safe(fin, hparams.n_audio_head); read_safe(loader, hparams.n_audio_head);
read_safe(fin, hparams.n_audio_layer); read_safe(loader, hparams.n_audio_layer);
read_safe(fin, hparams.n_text_ctx); read_safe(loader, hparams.n_text_ctx);
read_safe(fin, hparams.n_text_state); read_safe(loader, hparams.n_text_state);
read_safe(fin, hparams.n_text_head); read_safe(loader, hparams.n_text_head);
read_safe(fin, hparams.n_text_layer); read_safe(loader, hparams.n_text_layer);
read_safe(fin, hparams.n_mels); read_safe(loader, hparams.n_mels);
read_safe(fin, hparams.f16); read_safe(loader, hparams.f16);
assert(hparams.n_text_state == hparams.n_audio_state); assert(hparams.n_text_state == hparams.n_audio_state);
@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{ {
auto & filters = wctx.model.filters; auto & filters = wctx.model.filters;
read_safe(fin, filters.n_mel); read_safe(loader, filters.n_mel);
read_safe(fin, filters.n_fft); read_safe(loader, filters.n_fft);
filters.data.resize(filters.n_mel * filters.n_fft); filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float)); loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
} }
// load vocab // load vocab
{ {
int32_t n_vocab = 0; int32_t n_vocab = 0;
read_safe(fin, n_vocab); read_safe(loader, n_vocab);
//if (n_vocab != model.hparams.n_vocab) { //if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
for (int i = 0; i < n_vocab; i++) { for (int i = 0; i < n_vocab; i++) {
uint32_t len; uint32_t len;
read_safe(fin, len); read_safe(loader, len);
if (len > 0) { if (len > 0) {
tmp.resize(len); tmp.resize(len);
fin.read(&tmp[0], tmp.size()); // read to buffer loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
word.assign(&tmp[0], tmp.size()); word.assign(&tmp[0], tmp.size());
} else { } else {
// seems like we have an empty-string token in multi-language models (i = 50256) // seems like we have an empty-string token in multi-language models (i = 50256)
@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length; int32_t length;
int32_t ftype; int32_t ftype;
read_safe(fin, n_dims); read_safe(loader, n_dims);
read_safe(fin, length); read_safe(loader, length);
read_safe(fin, ftype); read_safe(loader, ftype);
if (fin.eof()) { if (loader->eof(loader->context)) {
break; break;
} }
int32_t nelements = 1; int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 }; int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) { for (int i = 0; i < n_dims; ++i) {
read_safe(fin, ne[i]); read_safe(loader, ne[i]);
nelements *= ne[i]; nelements *= ne[i];
} }
std::string name; std::string name;
std::vector<char> tmp(length); // create a buffer std::vector<char> tmp(length); // create a buffer
fin.read(&tmp[0], tmp.size()); // read to buffer loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
name.assign(&tmp[0], tmp.size()); name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) { if (model.tensors.find(name) == model.tensors.end()) {
@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
return false; return false;
} }
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor)); loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor); total_size += ggml_nbytes(tensor);
@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
} }
} }
fin.close();
return true; return true;
} }
@ -1479,6 +1471,7 @@ static bool whisper_encode(
} }
ggml_graph_compute(ctx0, &gf); ggml_graph_compute(ctx0, &gf);
//ggml_graph_print(&gf);
} }
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
@ -2240,7 +2233,74 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
// interface implementation // interface implementation
// //
struct whisper_context * whisper_init(const char * path_model) { struct whisper_context * whisper_init_from_file(const char * path_model) {
whisper_model_loader loader = {};
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
auto fin = std::ifstream(path_model, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
return nullptr;
}
loader.context = &fin;
loader.read = [](void * ctx, void * output, size_t read_size) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->read((char *)output, read_size);
return read_size;
};
loader.eof = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
return fin->eof();
};
loader.close = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->close();
};
return whisper_init(&loader);
}
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
struct buf_context {
uint8_t* buffer;
size_t size;
size_t current_offset;
};
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
whisper_model_loader loader = {};
fprintf(stderr, "%s: loading model from buffer\n", __func__);
loader.context = &ctx;
loader.read = [](void * ctx, void * output, size_t read_size) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
buf->current_offset += size_to_copy;
return size_to_copy;
};
loader.eof = [](void * ctx) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
return buf->current_offset >= buf->size;
};
loader.close = [](void * /*ctx*/) { };
return whisper_init(&loader);
}
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
ggml_time_init(); ggml_time_init();
whisper_context * ctx = new whisper_context; whisper_context * ctx = new whisper_context;
@ -2249,14 +2309,17 @@ struct whisper_context * whisper_init(const char * path_model) {
ctx->t_start_us = t_start_us; ctx->t_start_us = t_start_us;
if (!whisper_model_load(path_model, *ctx)) { if (!whisper_model_load(loader, *ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model); loader->close(loader->context);
fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx; delete ctx;
return nullptr; return nullptr;
} }
ctx->t_load_us = ggml_time_us() - t_start_us; ctx->t_load_us = ggml_time_us() - t_start_us;
loader->close(loader->context);
return ctx; return ctx;
} }
@ -3326,7 +3389,7 @@ static int timestamp_to_sample(int64_t t, int n_samples) {
} }
static int64_t sample_to_timestamp(int i_sample) { static int64_t sample_to_timestamp(int i_sample) {
return (100*i_sample)/WHISPER_SAMPLE_RATE; return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
} }
// a cost-function / heuristic that is high for text that takes longer to pronounce // a cost-function / heuristic that is high for text that takes longer to pronounce

@ -1,6 +1,7 @@
#ifndef WHISPER_H #ifndef WHISPER_H
#define WHISPER_H #define WHISPER_H
#include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
@ -40,7 +41,7 @@ extern "C" {
// //
// ... // ...
// //
// struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin"); // struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
// //
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
// fprintf(stderr, "failed to process audio\n"); // fprintf(stderr, "failed to process audio\n");
@ -84,9 +85,20 @@ extern "C" {
float vlen; // voice length of the token float vlen; // voice length of the token
} whisper_token_data; } whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file. typedef struct whisper_model_loader {
// Returns NULL on failure. void * context;
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// Frees all memory allocated by the model. // Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx); WHISPER_API void whisper_free(struct whisper_context * ctx);

@ -84,7 +84,7 @@ typedef void* thread_ret_t;
#define GGML_GELU_FP16 #define GGML_GELU_FP16
#define GGML_SOFT_MAX_UNROLL 4 #define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2
#ifdef GGML_USE_ACCELERATE #ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation // uncomment to use vDSP for soft max computation
@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv }; ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) { for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
} }
@ -1109,8 +1109,8 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
ggml_float sum = 0.0; ggml_float sum = 0.0;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
sum += x[i]; sum += x[i];
*s += sum;
} }
*s = sum;
#else #else
vDSP_sve(x, 1, s, n); vDSP_sve(x, 1, s, n);
#endif #endif
@ -3724,8 +3724,6 @@ static void ggml_compute_forward_sum_f32(
assert(ggml_is_scalar(dst)); assert(ggml_is_scalar(dst));
assert(src0->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float));
*(float *) (dst->data) = 0.0f;
const int ne00 = src0->ne[0]; const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1]; const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2]; const int ne02 = src0->ne[2];
@ -3811,8 +3809,6 @@ static void ggml_compute_forward_mean_f32(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { for (int i01 = 0; i01 < ne01; i01++) {
*(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f;
ggml_vec_sum_f32(ne00, ggml_vec_sum_f32(ne00,
(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@ -4791,7 +4787,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
} }
} }
} else { } else {
// parallelize by src1 columns using ggml_vec_mad_f32 // parallelize by src1 columns using ggml_vec_mad_f16
// each thread has its own work data // each thread has its own work data
// during FINALIZE we accumulate all work data into dst // during FINALIZE we accumulate all work data into dst
@ -6158,8 +6154,7 @@ static void ggml_compute_forward_flash_attn_f16(
S[i] = -INFINITY; S[i] = -INFINITY;
} }
// looks like unrolling here does not help if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
#if 1
for (int ic = 0; ic < nek1; ++ic) { for (int ic = 0; ic < nek1; ++ic) {
// k indices // k indices
const int ik3 = iq3; const int ik3 = iq3;
@ -6174,9 +6169,7 @@ static void ggml_compute_forward_flash_attn_f16(
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
} }
#else } else {
GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices // k indices
const int ik3 = iq3; const int ik3 = iq3;
@ -6191,7 +6184,7 @@ static void ggml_compute_forward_flash_attn_f16(
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
} }
#endif }
// scale // scale
ggml_vec_scale_f32(nek1, S, scale); ggml_vec_scale_f32(nek1, S, scale);
@ -6261,8 +6254,19 @@ static void ggml_compute_forward_flash_attn_f16(
S16[i] = GGML_FP32_TO_FP16(S[i]); S16[i] = GGML_FP32_TO_FP16(S[i]);
} }
GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0); if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
for (int ic = 0; ic < nev1; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_dot_f16(nek1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
S16);
}
} else {
for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
// dst indices // dst indices
const int i1 = iq1; const int i1 = iq1;
@ -6275,6 +6279,7 @@ static void ggml_compute_forward_flash_attn_f16(
S16); S16);
} }
} }
}
} }
static void ggml_compute_forward_flash_attn( static void ggml_compute_forward_flash_attn(

Loading…
Cancel
Save