From 0116c03fb7d697e92072d368646d850aa7250cdd Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 30 Sep 2022 19:16:07 +0300
Subject: [PATCH] whisper : various updates and improvements

---
 examples/whisper/README.md |  14 +-
 examples/whisper/main.cpp  | 302 +++++++++++++++++++++++++------
 2 files changed, 214 insertions(+), 102 deletions(-)

diff --git a/examples/whisper/README.md b/examples/whisper/README.md
index 68bc4bc..1834cdd 100644
--- a/examples/whisper/README.md
+++ b/examples/whisper/README.md
@@ -9,13 +9,13 @@ Checkout https://github.com/ggerganov/whisper.cpp
 
 ## Memory usage
 
-| Model | Mem |
-| --- | --- |
-| tiny.en | ~460 MB |
-| base.en | ~620 MB |
-| small.en | ~1.3 GB |
-| medium.en | ~2.8 GB |
-| large | ~4.9 GB |
+| Model | Disk | Mem |
+| --- | --- | --- |
+| tiny | 75 MB | ~240 MB |
+| base | 142 MB | ~380 MB |
+| small | 466 MB | ~970 MB |
+| medium | 1.5 GB | ~2.5 GB |
+| large | 2.9 GB | ~4.6 GB |
 
 ## ggml format
 
diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp
index 326a8a7..79935c9 100644
--- a/examples/whisper/main.cpp
+++ b/examples/whisper/main.cpp
@@ -158,11 +158,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,   190ull*MB },
-    { MODEL_BASE,   190ull*MB },
-    { MODEL_SMALL,  190ull*MB },
-    { MODEL_MEDIUM, 200ull*MB },
-    { MODEL_LARGE,  200ull*MB },
+    { MODEL_TINY,    94ull*MB },
+    { MODEL_BASE,    96ull*MB },
+    { MODEL_SMALL,   98ull*MB },
+    { MODEL_MEDIUM, 100ull*MB },
+    { MODEL_LARGE,  102ull*MB },
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
@@ -173,6 +173,11 @@ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
     { MODEL_LARGE,  110ull*MB },
 };
 
+// the memory buffers used to store the model in memory and perform the inference computations
+std::vector<uint8_t> g_buf_model;
+std::vector<uint8_t> g_buf_compute;
+std::vector<uint8_t> g_buf_compute_layer;
+
 const int SAMPLE_RATE = 16000;
 const int N_FFT = 400;
 const int N_MEL = 80;
@@ -206,6 +211,7 @@ struct whisper_vocab {
     id token_sot  = 50257;
     id token_prev = 50360;
     id token_solm = 50361; // ??
+ id token_not = 50362; // no timestamps id token_beg = 50363; // available tasks @@ -217,17 +223,20 @@ struct whisper_vocab { } }; +struct whisper_result { + whisper_vocab::id id; + int64_t t; +}; + // command-line parameters struct whisper_params { int32_t seed = -1; // RNG seed, not used currently int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - // sampling parameter - used for the greedy strategy - int32_t max_tokens_per_iter = 64; - bool verbose = false; bool translate = false; bool print_special_tokens = false; + bool no_timestamps = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -244,8 +253,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.seed = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-T" || arg == "--tokens") { - params.max_tokens_per_iter = std::stoi(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; } else if (arg == "--translate") { @@ -259,6 +266,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } } else if (arg == "-ps" || arg == "--print_special") { params.print_special_tokens = true; + } else if (arg == "-nt" || arg == "--no_timestamps") { + params.no_timestamps = true; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { @@ -284,10 +293,10 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); + fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str()); @@ -538,13 +547,15 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe printf("%s: f16 = %d\n", __func__, hparams.f16); printf("%s: type = %d\n", __func__, model.type); + g_buf_model.resize(MEM_REQ_MODEL.at(model.type)); + g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + // this is the total memory required to run the inference const size_t mem_required = - MEM_REQ_MODEL.at(model.type) + - MEM_REQ_ENCODE.at(model.type) + - MEM_REQ_ENCODE_LAYER.at(model.type) + - MEM_REQ_DECODE.at(model.type) + - MEM_REQ_DECODE_LAYER.at(model.type); + g_buf_model.size() + + g_buf_compute.size() + + g_buf_compute_layer.size(); printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); } @@ -591,6 +602,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe vocab.token_sot++; 
vocab.token_prev++; vocab.token_solm++; + vocab.token_not++; vocab.token_beg++; } @@ -605,6 +617,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe word = "[_SOT_]"; } else if (i == vocab.token_prev) { word = "[_PREV_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; } else if (i == vocab.token_beg) { word = "[_BEG_]"; } else { @@ -745,8 +759,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe // create the ggml context { struct ggml_init_params params = { - .mem_size = ctx_size, - .mem_buffer = NULL, + .mem_size = g_buf_model.size(), + .mem_buffer = g_buf_model.data(), }; model.ctx = ggml_init(params); @@ -1082,17 +1096,10 @@ bool whisper_encode( const int n_mels = hparams.n_mels; assert(mel_inp.n_mel == n_mels); - struct ggml_init_params params; - - { - static size_t buf_size = MEM_REQ_ENCODE.at(model.type); - static void * buf = malloc(buf_size); - - params = { - .mem_size = buf_size, - .mem_buffer = buf, - }; - } + struct ggml_init_params params = { + .mem_size = g_buf_compute.size(), + .mem_buffer = g_buf_compute.data(), + }; struct ggml_context * ctx0 = ggml_init(params); @@ -1144,16 +1151,10 @@ bool whisper_encode( // create separate context for each layer to reduce memory usage - struct ggml_init_params paramsL; - { - static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type); - static void * buf = malloc(buf_size); - - paramsL = { - .mem_size = buf_size, - .mem_buffer = buf, - }; - } + struct ggml_init_params paramsL = { + .mem_size = g_buf_compute_layer.size(), + .mem_buffer = g_buf_compute_layer.data(), + }; struct ggml_context * ctxL = ggml_init(paramsL); @@ -1485,17 +1486,10 @@ bool whisper_decode( const int N = prompt.size(); const int M = hparams.n_audio_ctx; - struct ggml_init_params params; - - { - static size_t buf_size = MEM_REQ_DECODE.at(model.type); - static void * buf = malloc(buf_size); - - params = { - .mem_size = buf_size, - .mem_buffer = buf, + struct ggml_init_params params = { + .mem_size = g_buf_compute.size(), + .mem_buffer = g_buf_compute.data(), }; - } struct ggml_context * ctx0 = ggml_init(params); @@ -1518,17 +1512,10 @@ bool whisper_decode( for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_decoder[il]; - struct ggml_init_params paramsL; - - { - static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type); - static void * buf = malloc(buf_size); - - paramsL = { - .mem_size = buf_size, - .mem_buffer = buf, - }; - } + struct ggml_init_params paramsL = { + .mem_size = g_buf_compute_layer.size(), + .mem_buffer = g_buf_compute_layer.data(), + }; struct ggml_context * ctxL = ggml_init(paramsL); struct ggml_cgraph gf = { .n_threads = n_threads }; @@ -1842,19 +1829,17 @@ bool whisper_decode( // TODO: temperature whisper_vocab::id whisper_sample_best( const whisper_vocab & vocab, - const float * probs, - double temp, - int offset = 0) { + const float * probs, bool need_timestamp) { int n_logits = vocab.id_to_token.size(); std::vector> probs_id; probs_id.reserve(n_logits); - for (int i = offset; i < n_logits; i++) { + for (int i = 0; i < n_logits; i++) { probs_id.push_back(std::make_pair(probs[i], i)); } - const int top_k = 10; + const int top_k = 4; // find the top K tokens std::partial_sort( @@ -1871,14 +1856,59 @@ whisper_vocab::id whisper_sample_best( // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); //} + if (need_timestamp) { + // at the end of the 30-second audio segment, we start 
giving preference to time tokens + for (int i = 0; i < top_k; i++) { + if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) { + return probs_id[i].second; + } + } + } + int res = 0; - while (probs_id[res].second == vocab.token_solm && res < (int) probs_id.size() - 1) { + while ((probs_id[res].second == vocab.token_sot || + probs_id[res].second == vocab.token_solm || + probs_id[res].second == vocab.token_not) && + res < (int) probs_id.size() - 1) { res++; } return probs_id[res].second; } +// samples only from the timestamps tokens +whisper_vocab::id whisper_sample_timestamp( + const whisper_vocab & vocab, + const float * probs) { + int n_logits = vocab.id_to_token.size(); + + std::vector> probs_id; + probs_id.reserve(n_logits); + + for (int i = vocab.token_beg + 1; i < n_logits; i++) { + probs_id.push_back(std::make_pair(probs[i], i)); + } + + const int top_k = 10; + + // find the top K tokens + std::partial_sort( + probs_id.begin(), + probs_id.begin() + top_k, probs_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + probs_id.resize(top_k); + + //printf("\n"); + //for (int i = 0; i < (int) probs_id.size(); i++) { + // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); + //} + + return probs_id[0].second; +} + // Cooley-Tukey FFT // poor man's implmentation - use something better // input is real-valued @@ -2032,6 +2062,20 @@ bool log_mel_spectrogram( return true; } +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t) { + int64_t sec = t/100; + int64_t msec = t - sec*100; + int64_t min = sec/60; + sec = sec - min*60; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec); + + return std::string(buf); +} + int main(int argc, char ** argv) { const int64_t t_main_start_us = ggml_time_us(); @@ -2051,7 +2095,7 @@ int main(int argc, char ** argv) { int64_t t_load_us = 0; int64_t t_mel_us = 0; - int64_t t_sample_us = 0; + int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; @@ -2128,10 +2172,12 @@ int main(int argc, char ** argv) { printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } - printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n", + printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads, g_lang.at(params.language).second.c_str(), - params.translate ? "translate" : "transcribe"); + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 
0 : 1); + printf("\n"); } // the accumulated text context so far @@ -2148,6 +2194,9 @@ int main(int argc, char ** argv) { } } + // the generated text including timestamps + std::vector result_all; + // main loop int seek = 0; while (true) { @@ -2165,7 +2214,7 @@ int main(int argc, char ** argv) { return 1; } - t_encode_us = ggml_time_us() - t_start_us; + t_encode_us += ggml_time_us() - t_start_us; } std::vector probs; @@ -2192,11 +2241,17 @@ int main(int argc, char ** argv) { int seek_delta = 100*CHUNK_SIZE; whisper_vocab::id last_id = 0; + // print the prompt + //printf("\n\n"); //for (int i = 0; i < prompt.size(); i++) { // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str()); //} + //printf("\n\n"); + + // the accumulated transcription in the current interation + int result_len = 0; + std::vector result_cur; - printf("\n"); for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) { // decode if (prompt.size() > 0) { @@ -2216,63 +2271,120 @@ int main(int argc, char ** argv) { // very basic greedy sampling strategy: // // - always take the most probable token - // - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token - // and advance the sliding window by that amount - // - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too // // more sophisticated sampling strategies could be implemented here, but we keep it simple // feel free to experiment! // { - // sample next token - const float temp = 1.0; // TODO - const int n_vocab = model.hparams.n_vocab; - whisper_vocab::id id = 0; + whisper_vocab::id id = 0; + whisper_vocab::id tid = vocab.token_beg; { const int64_t t_start_sample_us = ggml_time_us(); - id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), temp, i > params.max_tokens_per_iter ? 
vocab.token_beg : 0); + id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0); + if (i > 0) { + tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab)); + } t_sample_us += ggml_time_us() - t_start_sample_us; } - // end of text token - if (id == vocab.token_eot) { - break; - } - - // 2 consecutive time tokens - if (id > vocab.token_beg && last_id > vocab.token_beg) { + // update sliding window + if (id > vocab.token_beg) { seek_delta = 2*(id - vocab.token_beg); - done = true; + result_len = i + 1; } last_id = id; // add it to the context prompt.push_back(id); - prompt_past.push_back(id); - } + result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) }); + + //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str()); - // display text - for (auto id : prompt) { - if (params.print_special_tokens == false && id >= vocab.token_eot) { - continue; + // end of text token + if (id == vocab.token_eot) { + break; } - printf("%s", vocab.id_to_token[id].c_str()); } - fflush(stdout); if (done) { break; } } + result_cur.resize(result_len); + result_all.insert(result_all.end(), result_cur.begin(), result_cur.end()); + + for (const auto & r : result_cur) { + prompt_past.push_back(r.id); + } + + // print the text from this iteration + if (result_cur.size() > 0) { + auto t0 = result_cur.front().t; + + std::string text = ""; + for (int i = 0; i < result_cur.size(); i++) { + if (params.print_special_tokens == false && result_cur[i].id >= vocab.token_eot) { + } else { + text += vocab.id_to_token[result_cur[i].id]; + } + if (result_cur[i].id > vocab.token_beg) { + const auto t1 = result_cur[i].t; + if (!text.empty()) { + if (params.no_timestamps) { + printf ("%s", text.c_str()); + fflush(stdout); + } else { + printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); + } + } + text = ""; + while (result_cur[i].id > vocab.token_beg && i < result_cur.size()) { + i++; + } + i--; + t0 = result_cur[i].t; + } + } + + if (!text.empty()) { + printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(seek + seek_delta).c_str(), text.c_str()); + } + } + seek += seek_delta; } + // WIP: attempt for per-token timestamps + //if (!params.no_timestamps && result_all.size() > 0) { + // const int64_t dt = 500; // 5 second intervals + + // int i0 = 0; + + // int64_t t0 = result_all[0].t; + // int64_t t1 = t0; + + // printf("\n\n"); + // for (int i = 0; i < result_all.size(); ++i) { + // printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t); + // if (result_all[i].t - t0 > dt) { + // t1 = result_all[i - 1].t; + // printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + // for (int j = i0; j < i; ++j) { + // printf("%s", vocab.id_to_token.at(result_all[j].id).c_str()); + // } + // printf("\n"); + // i0 = i; + // t0 = result_all[i].t; + // } + // } + //} + // report timing { const int64_t t_main_end_us = ggml_time_us();
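
Note on the memory-buffer changes above: the patch replaces the per-call static malloc'ed scratch buffers in whisper_encode and whisper_decode with three globally allocated std::vector<uint8_t> buffers (g_buf_model, g_buf_compute, g_buf_compute_layer) that are sized once during model loading and handed to ggml_init as pre-allocated memory. The sketch below shows that pattern in isolation; the buffer size and function name are illustrative assumptions, while ggml_init_params, ggml_init and ggml_free are used as in the patch.

#include <cstdint>
#include <vector>

#include "ggml.h"

// one reusable compute buffer, allocated once up front (size is illustrative)
static std::vector<uint8_t> g_buf_compute(256u*1024*1024);

void evaluate_step() {
    // ggml places all tensors inside the caller-provided buffer,
    // so no per-evaluation allocations are made
    struct ggml_init_params params = {
        .mem_size   = g_buf_compute.size(),
        .mem_buffer = g_buf_compute.data(),
    };

    struct ggml_context * ctx0 = ggml_init(params);

    // ... build the graph and run it here ...

    ggml_free(ctx0);
}

Because the encoder and decoder never run at the same time, the patch sizes g_buf_compute to the maximum of the encode and decode requirements and reuses the same buffer for both.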
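Note on the new timestamp handling: token ids above token_beg encode time, with each step past token_beg worth 20 ms, while the sliding-window offset seek and the whisper_result::t values are kept in 10 ms units (mel frames). That is why the decoder advances the window by seek_delta = 2*(id - token_beg) and why to_timestamp divides by 100 to obtain seconds. A minimal self-contained sketch of this arithmetic follows; the sampled token id is a made-up example, token_beg = 50363 is taken from the patch, and unlike the patch's to_timestamp the fractional part is converted to true milliseconds here.

#include <cstdint>
#include <cstdio>
#include <string>

// first timestamp token id, as defined in whisper_vocab in the patch
static const int TOKEN_BEG = 50363;

// format a time given in 10 ms units (mel frames) as "mm:ss.mmm"
static std::string to_timestamp(int64_t t) {
    int64_t msec = t*10;                   // 10 ms units -> milliseconds
    int64_t sec  = msec/1000; msec -= sec*1000;
    int64_t min  = sec/60;    sec  -= min*60;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
    return std::string(buf);
}

int main() {
    const int64_t seek   = 0;     // window offset in 10 ms units
    const int sampled_id = 50613; // hypothetical timestamp token, 250 steps past TOKEN_BEG

    // each timestamp token step corresponds to 20 ms, i.e. 2 mel frames
    const int64_t t = seek + 2*(sampled_id - TOKEN_BEG);

    printf("[%s --> %s]\n", to_timestamp(seek).c_str(), to_timestamp(t).c_str());
    // prints: [00:00.000 --> 00:05.000]
    return 0;
}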