diff --git a/Makefile b/Makefile
index 363c263..35cfe78 100644
--- a/Makefile
+++ b/Makefile
@@ -20,10 +20,13 @@ samples:
 	@wget --quiet --show-progress -O samples/gb0.ogg https://upload.wikimedia.org/wikipedia/commons/2/22/George_W._Bush%27s_weekly_radio_address_%28November_1%2C_2008%29.oga
 	@wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
 	@wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
+	@wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
 	@echo "Converting to 16-bit WAV ..."
 	@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
 	@ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
 	@ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
+	@ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
+	@rm samples/mm1.wav
 
 # if not already downloaded, the following targets download the specified model and
diff --git a/README.md b/README.md
index 5b06c12..068636a 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
 # whisper.cpp
 
-C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-text model
+High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
 
 - Plain C/C++ implementation without dependencies
 - ARM_NEON and AVX intrinsics support
 - Mixed F16 / F32 support
 - Low memory usage (Flash Attention + Flash Forward)
 - Zero memory allocations at runtime
+- Runs on the CPU (Mac and Linux support)
 
 ## Usage
 
@@ -50,7 +51,12 @@ options:
 bash ./download-ggml-model.sh base.en
 
 Downloading ggml model base.en ...
-Model base.en already exists. Skipping download.
+models/ggml-base.en.bin 100%[=====================================>] 141.11M 8.58MB/s in 22s
+
+Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
+You can now use it like this:
+
+  $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
+
 
 ===============================================
 Running base.en on all samples in ./samples ...
@@ -73,7 +79,7 @@ whisper_model_load: n_text_layer = 6
 whisper_model_load: n_mels = 80
 whisper_model_load: f16 = 1
 whisper_model_load: type = 2
-whisper_model_load: mem_required = 611.00 MB
+whisper_model_load: mem_required = 377.00 MB
 whisper_model_load: adding 1607 extra tokens
 whisper_model_load: ggml ctx size = 163.43 MB
 whisper_model_load: memory size = 22.83 MB
@@ -86,12 +92,12 @@ main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = tr
 
 [00:00.000 --> 00:11.000]   And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
 
-main: load time = 61.78 ms
-main: mel time = 41.74 ms
-main: sample time = 2.10 ms
-main: encode time = 718.60 ms / 119.77 ms per layer
-main: decode time = 83.55 ms
-main: total time = 908.15 ms
+main: load time = 82.05 ms
+main: mel time = 44.15 ms
+main: sample time = 1.98 ms
+main: encode time = 674.77 ms / 112.46 ms per layer
+main: decode time = 82.91 ms
+main: total time = 886.29 ms
 ```
 
 The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -131,10 +137,12 @@ make large
 
 ## Another example
 
-Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg) in less than a minute, using `medium.en` model:
+Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
+in less than a minute on a MacBook M1 Pro, using `medium.en` model:
 
 ```java
 $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
+
 whisper_model_load: loading model from 'models/ggml-medium.en.bin'
 whisper_model_load: n_vocab = 51864
 whisper_model_load: n_audio_ctx = 1500
@@ -148,7 +156,7 @@ whisper_model_load: n_text_layer = 24
 whisper_model_load: n_mels = 80
 whisper_model_load: f16 = 1
 whisper_model_load: type = 4
-whisper_model_load: mem_required = 2786.00 MB
+whisper_model_load: mem_required = 2502.00 MB
 whisper_model_load: adding 1607 extra tokens
 whisper_model_load: ggml ctx size = 1644.97 MB
 whisper_model_load: memory size = 182.62 MB
@@ -187,30 +195,30 @@ main: processing 3179750 samples (198.7 sec), 8 threads, lang = english, task =
 
 [03:14.000 --> 03:24.000]   [Music]
 
-main: load time = 438.55 ms
-main: mel time = 440.22 ms
-main: sample time = 32.23 ms
-main: encode time = 42329.63 ms / 1763.73 ms per layer
-main: decode time = 15190.00 ms
-main: total time = 58444.63 ms
+main: load time = 522.18 ms
+main: mel time = 423.43 ms
+main: sample time = 31.42 ms
+main: encode time = 41518.51 ms / 1729.94 ms per layer
+main: decode time = 14907.22 ms
+main: total time = 57416.63 ms
 ```
 
 ## Limitations
 
 - Very basic greedy sampling scheme - always pick up the top token
+- Only 16-bit WAV at 16 kHz is supported
 - Inference only
-- Runs on the CPU
-- Only mono-channel 16-bit WAV is supported
+- No GPU support
 
 ## Memory usage
 
-| Model | Disk | Mem |
-| --- | --- | --- |
-| tiny | 75 MB | ~460 MB |
-| base | 142 MB | ~620 MB |
-| small | 466 MB | ~1.3 GB |
-| medium | 1.5 GB | ~2.8 GB |
-| large | 2.9 GB | ~4.9 GB |
+| Model | Disk | Mem |
+| --- | --- | --- |
+| tiny | 75 MB | ~240 MB |
+| base | 142 MB | ~380 MB |
+| small | 466 MB | ~970 MB |
+| medium | 1.5 GB | ~2.5 GB |
+| large | 2.9 GB | ~4.6 GB |
 
 ## ggml format
diff --git a/main.cpp b/main.cpp
index ac20531..4822aa9 100644
--- a/main.cpp
+++ b/main.cpp
@@ -158,11 +158,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY, 190ull*MB },
-    { MODEL_BASE, 190ull*MB },
-    { MODEL_SMALL, 190ull*MB },
-    { MODEL_MEDIUM, 200ull*MB },
-    { MODEL_LARGE, 200ull*MB },
+    { MODEL_TINY, 94ull*MB },
+    { MODEL_BASE, 96ull*MB },
+    { MODEL_SMALL, 98ull*MB },
+    { MODEL_MEDIUM, 100ull*MB },
+    { MODEL_LARGE, 102ull*MB },
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
@@ -173,6 +173,11 @@ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
     { MODEL_LARGE, 110ull*MB },
 };
 
+// the memory buffers used to store the model in memory and perform the inference computations
+std::vector<uint8_t> g_buf_model;
+std::vector<uint8_t> g_buf_compute;
+std::vector<uint8_t> g_buf_compute_layer;
+
 const int SAMPLE_RATE = 16000;
 const int N_FFT = 400;
 const int N_MEL = 80;
@@ -542,13 +547,15 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
         printf("%s: f16 = %d\n", __func__, hparams.f16);
         printf("%s: type = %d\n", __func__, model.type);
 
+        g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
+        g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+        g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+
         // this is the total memory required to run the inference
         const size_t mem_required =
-            MEM_REQ_MODEL.at(model.type) +
-            MEM_REQ_ENCODE.at(model.type) +
-            MEM_REQ_ENCODE_LAYER.at(model.type) +
-            MEM_REQ_DECODE.at(model.type) +
-            MEM_REQ_DECODE_LAYER.at(model.type);
+            g_buf_model.size() +
+            g_buf_compute.size() +
+            g_buf_compute_layer.size();
 
         printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
     }
@@ -752,8 +759,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size = ctx_size,
-            .mem_buffer = NULL,
+            .mem_size = g_buf_model.size(),
+            .mem_buffer = g_buf_model.data(),
         };
 
         model.ctx = ggml_init(params);
@@ -1089,17 +1096,10 @@ bool whisper_encode(
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_ENCODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size = buf_size,
-            .mem_buffer = buf,
-        };
-    }
+    struct ggml_init_params params = {
+        .mem_size = g_buf_compute.size(),
+        .mem_buffer = g_buf_compute.data(),
+    };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1151,16 +1151,10 @@ bool whisper_encode(
         // create separate context for each layer to reduce memory usage
 
-        struct ggml_init_params paramsL;
-        {
-            static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };
 
         struct ggml_context * ctxL = ggml_init(paramsL);
 
@@ -1492,17 +1486,10 @@ bool whisper_decode(
     const int N = prompt.size();
     const int M = hparams.n_audio_ctx;
 
-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_DECODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size = buf_size,
-            .mem_buffer = buf,
+    struct ggml_init_params params = {
+        .mem_size = g_buf_compute.size(),
+        .mem_buffer = g_buf_compute.data(),
         };
-    }
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1525,17 +1512,10 @@ bool whisper_decode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
-        struct ggml_init_params paramsL;
-
-        {
-            static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };
 
         struct ggml_context * ctxL = ggml_init(paramsL);
         struct ggml_cgraph gf = { .n_threads = n_threads };
@@ -1849,7 +1829,7 @@ bool whisper_decode(
 // TODO: temperature
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs, bool need_timestamp) {
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1859,7 +1839,7 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    const int top_k = 10;
+    const int top_k = 4;
 
     // find the top K tokens
     std::partial_sort(
@@ -1876,6 +1856,15 @@ whisper_vocab::id whisper_sample_best(
     // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
     //}
 
+    if (need_timestamp) {
+        // at the end of the 30-second audio segment, we start giving preference to time tokens
+        for (int i = 0; i < top_k; i++) {
+            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) {
+                return probs_id[i].second;
+            }
+        }
+    }
+
     int res = 0;
     while ((probs_id[res].second == vocab.token_sot ||
             probs_id[res].second == vocab.token_solm ||
@@ -2136,8 +2125,8 @@ int main(int argc, char ** argv) {
             return 2;
         }
 
-        if (wav.channels != 1) {
-            fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
+        if (wav.channels != 1 && wav.channels != 2) {
+            fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
             return 3;
         }
 
@@ -2158,8 +2147,14 @@ int main(int argc, char ** argv) {
 
         // convert to float
         pcmf32.resize(pcm16.size());
-        for (size_t i = 0; i < pcm16.size(); i++) {
-            pcmf32[i] = float(pcm16[i])/32768.0f;
+        if (wav.channels == 1) {
+            for (size_t i = 0; i < pcm16.size(); i++) {
+                pcmf32[i] = float(pcm16[i])/32768.0f;
+            }
+        } else {
+            for (size_t i = 0; i < pcm16.size(); i++) {
+                pcmf32[i] = float(pcm16[i*2 + 0] + pcm16[i*2 + 1])/32768.0f/2.0f;
+            }
         }
     }
 
@@ -2252,6 +2247,7 @@ int main(int argc, char ** argv) {
     int seek_delta = 100*CHUNK_SIZE;
     whisper_vocab::id last_id = 0;
 
+    // print the prompt
     //printf("\n\n");
     //for (int i = 0; i < prompt.size(); i++) {
     //    printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
@@ -2294,7 +2290,7 @@ int main(int argc, char ** argv) {
         {
             const int64_t t_start_sample_us = ggml_time_us();
 
-            id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab));
+            id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
 
             if (i > 0) {
                 tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
             }
@@ -2313,6 +2309,8 @@ int main(int argc, char ** argv) {
             prompt.push_back(id);
             result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
 
+            //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
+
            // end of text token
            if (id == vocab.token_eot) {
                break;
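
For reference, below is a minimal standalone sketch of the stereo-to-mono conversion that the last `main.cpp` hunks enable. It is not taken from the patch: the helper name `pcm16_to_f32_mono` is made up for illustration, it sizes the output by frame count (`pcm16.size() / channels`) rather than by total sample count, and it assumes interleaved left/right samples, which is the usual layout of 16-bit PCM WAV data.

```cpp
#include <cstdint>
#include <vector>

// Convert interleaved 16-bit PCM to normalized 32-bit float samples in [-1, 1).
// Mono input is copied through; stereo input is downmixed by averaging the
// left and right channels, following the idea of the change above.
static std::vector<float> pcm16_to_f32_mono(const std::vector<int16_t> & pcm16, int channels) {
    const size_t n_frames = pcm16.size() / channels; // samples per channel (assumption of this sketch)

    std::vector<float> pcmf32(n_frames);

    if (channels == 1) {
        for (size_t i = 0; i < n_frames; i++) {
            pcmf32[i] = float(pcm16[i])/32768.0f;
        }
    } else {
        for (size_t i = 0; i < n_frames; i++) {
            // average L and R, then scale int16 range down to float
            pcmf32[i] = float(pcm16[2*i + 0] + pcm16[2*i + 1])/32768.0f/2.0f;
        }
    }

    return pcmf32;
}
```

The input still has to be 16-bit PCM sampled at 16 kHz; the Makefile's `ffmpeg ... -ar 16000 -ac 1 -c:a pcm_s16le` invocation above is an easy way to produce such a file.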