From 6feeca262f3cbb995d5f3777643531be614caef0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Nov 2022 22:13:15 +0200 Subject: [PATCH] sync : latest changes from whisper.cpp --- README.md | 4 +- examples/whisper/main.cpp | 703 ++++++++++++++++++++++++++++++----- examples/whisper/whisper.cpp | 427 +++++++++++++++++---- examples/whisper/whisper.h | 47 ++- include/ggml/ggml.h | 13 +- src/ggml.c | 386 ++++++++++++++++--- src/msvc_thread_atomic.h | 31 -- 7 files changed, 1361 insertions(+), 250 deletions(-) delete mode 100644 src/msvc_thread_atomic.h diff --git a/README.md b/README.md index fc065c0..8c3f862 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ Tensor library for machine learning - No third-party dependencies - Zero memory allocations during runtime -*Note that this project is under development and not ready for production use* +***Note that this project is under development and not ready for production use. +More active development is happening in the ***[whisper.cpp](https://github.com/ggerganov/whisper.cpp) ***repo +so if you are interested in this project, make sure to follow what is happening there*** ## Whisper inference (example) diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 995eefc..1be0032 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -5,15 +5,23 @@ #define DR_WAV_IMPLEMENTATION #include "dr_wav.h" +#include #include #include #include #include #include +// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] +// Lowest is red, middle is yellow, highest is green. +const std::vector k_colors = { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", +}; + // 500 -> 00:05.000 // 6000 -> 01:00.000 -std::string to_timestamp(int64_t t) { +std::string to_timestamp(int64_t t, bool comma = false) { int64_t msec = t * 10; int64_t hr = msec / (1000 * 60 * 60); msec = msec - hr * (1000 * 60 * 60); @@ -23,23 +31,64 @@ std::string to_timestamp(int64_t t) { msec = msec - sec * 1000; char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%03d", (int) hr, (int) min, (int) sec, (int) msec); + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); return std::string(buf); } +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + for (size_t pos = 0; ; pos += replace.length()) { + pos = s.find(search, pos); + if (pos == std::string::npos) break; + s.erase(pos, search.length()); + s.insert(pos, replace); + } +} + +// a cost-function that is high for text that takes longer to pronounce +float voice_length(const std::string & text) { + float res = 0.0f; + + for (size_t i = 0; i < text.size(); ++i) { + if (text[i] == ' ') { + res += 0.01f; + } else if (text[i] == ',') { + res += 2.00f; + } else if (text[i] == '.') { + res += 3.00f; + } else if (text[i] == '!') { + res += 3.00f; + } else if (text[i] == '?') { + res += 3.00f; + } else if (text[i] >= '0' && text[i] <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + // command-line parameters struct whisper_params { - int32_t seed = -1; // RNG seed, not used currently - int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t offset_ms = 0; + int32_t seed = -1; // RNG seed, not used currently + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t max_context = -1; + + float word_thold = 0.01f; bool verbose = false; bool translate = false; bool output_txt = false; bool output_vtt = false; bool output_srt = false; + bool output_wts = false; bool print_special_tokens = false; + bool print_colors = false; bool no_timestamps = false; std::string language = "en"; @@ -63,8 +112,16 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.seed = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-o" || arg == "--offset") { - params.offset_ms = std::stoi(argv[++i]); + } else if (arg == "-p" || arg == "--processors") { + params.n_processors = std::stoi(argv[++i]); + } else if (arg == "-ot" || arg == "--offset-t") { + params.offset_t_ms = std::stoi(argv[++i]); + } else if (arg == "-on" || arg == "--offset-n") { + params.offset_n = std::stoi(argv[++i]); + } else if (arg == "-mc" || arg == "--max-context") { + params.max_context = std::stoi(argv[++i]); + } else if (arg == "-wt" || arg == "--word-thold") { + params.word_thold = std::stof(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; } else if (arg == "--translate") { @@ -82,8 +139,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.output_vtt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; + } else if (arg == "-owts" || arg == "--output-words") { + params.output_wts = true; } else if (arg == "-ps" || arg == "--print_special") { params.print_special_tokens = true; + } else if (arg == "-pc" || arg == "--print_colors") { + params.print_colors = true; } else if (arg == "-nt" || arg == "--no_timestamps") { params.no_timestamps = true; } else if (arg == "-m" || arg == "--model") { @@ -111,13 +172,19 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -o N, --offset N offset in milliseconds (default: %d)\n", params.offset_ms); + fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); + fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); + fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); + fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); + fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); @@ -125,6 +192,505 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } +void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { + const whisper_params & params = *(whisper_params *) user_data; + + const int n_segments = whisper_full_n_segments(ctx); + + // print the last segment + const int i = n_segments - 1; + if (i == 0) { + printf("\n"); + } + + if (params.no_timestamps) { + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); + } + fflush(stdout); + } else { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } + } +} + +bool output_txt(struct whisper_context * ctx, const char * fname) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + fout << text; + } + + return true; +} + +bool output_vtt(struct whisper_context * ctx, const char * fname) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return 9; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + fout << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + fout << text << "\n\n"; + } + + return true; +} + +bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout << i + 1 + params.offset_n << "\n"; + fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + fout << text << "\n\n"; + } + + return true; +} + +// word-level timestamps (experimental) +// TODO: probably still has bugs, needs refactoring, etc.. +// TODO: auto threshold +// TODO: extra pass to detect unused speech and assign to tokens +// TODO: font parameter adjustments +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector & pcmf32) { + if (params.output_wts) { + std::vector pcm_avg(pcmf32.size(), 0); + + // average the fabs of the signal + { + const int hw = 32; + + for (int i = 0; i < pcmf32.size(); i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < pcmf32.size()) { + sum += fabs(pcmf32[i + j]); + } + } + pcm_avg[i] = sum/(2*hw + 1); + } + } + + struct token_info { + int64_t t0 = -1; + int64_t t1 = -1; + + int64_t tt0 = -1; + int64_t tt1 = -1; + + whisper_token id; + whisper_token tid; + + float p = 0.0f; + float pt = 0.0f; + float ptsum = 0.0f; + + std::string text; + float vlen = 0.0f; // voice length of this token + }; + + int64_t t_beg = 0; + int64_t t_last = 0; + + whisper_token tid_last = 0; + + std::ofstream fout(fname); + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + fout << "!/bin/bash" << "\n"; + fout << "\n"; + + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \""; + + bool is_first = true; + + for (int i = 0; i < whisper_full_n_segments(ctx); i++) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + const char *text = whisper_full_get_segment_text(ctx, i); + + const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); + const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100)); + + const int n = whisper_full_n_tokens(ctx, i); + + std::vector tokens(n); + + if (n <= 1) { + continue; + } + + for (int j = 0; j < n; ++j) { + struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j); + + if (j == 0) { + if (token.id == whisper_token_beg(ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].text = whisper_token_to_str(ctx, token.id); + //tokens[j].vlen = tokens[j].pt; + tokens[j].vlen = voice_length(tokens[j].text); + + if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + int p0 = 0; + int p1 = 0; + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + //const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1); + //const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1); + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + + tokens[j].tt0 = tokens[j].t0; + tokens[j].tt1 = tokens[j].t1; + } + + // VAD + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + const int64_t t0 = tokens[j].t0; + const int64_t t1 = tokens[j].t1; + + int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); + int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100)); + + const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw); + const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw); + + const int n = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += pcm_avg[k]; + } + + const float thold = 0.5*sum/n; + + { + int k = s0; + if (pcm_avg[k] > thold && j > 0) { + while (k > 0 && pcm_avg[k] > thold) { + k--; + } + tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (pcm_avg[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE; + } + } + + { + int k = s1; + if (pcm_avg[k] > thold) { + while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) { + k++; + } + tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; + if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (pcm_avg[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; + } + } + } + } + + const int t_expand = 0; + + for (int j = 0; j < n; j++) { + if (j > 0) { + tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + } + if (j < n - 1) { + tokens[j].t1 = tokens[j].t1 + t_expand; + } + } + + for (int j = 0; j < n; ++j) { + const auto & token = tokens[j]; + const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; + printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str()); + + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id)); + + //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n"; + } + + static const int line_wrap = 60; + static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + + if (!is_first) { + fout << ","; + } + + // background text + fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; + + is_first = false; + + for (int j = 0; j < n; ++j) { + const auto & token = tokens[j]; + + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + std::string txt_bg; + std::string txt_fg; // highlight token + std::string txt_ul; // underline + + txt_bg = "> "; + txt_fg = "> "; + txt_ul = "\\ \\ "; + + { + int ncnt = 0; + for (int k = 0; k < n; ++k) { + const auto & token2 = tokens[k]; + + if (tokens[k].id >= whisper_token_eot(ctx)) { + continue; + } + + const std::string txt = whisper_token_to_str(ctx, token2.id); + + txt_bg += txt; + + if (k == j) { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += txt[l]; + txt_ul += "_"; + } + txt_fg += "|"; + } else { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += "\\ "; + txt_ul += "\\ "; + } + } + + ncnt += txt.size(); + + if (ncnt > line_wrap) { + if (k < j) { + txt_bg = "> "; + txt_fg = "> "; + txt_ul = "\\ \\ "; + ncnt = 0; + } else { + break; + } + } + } + + ::replace_all(txt_bg, "'", "’"); + ::replace_all(txt_bg, "\"", "\\\""); + ::replace_all(txt_fg, "'", "’"); + ::replace_all(txt_fg, "\"", "\\\""); + } + + // background text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'"; + + // foreground text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + + // underline + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + } + } + + fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n"; + + fout << "\n\n"; + fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n"; + fout << "\n"; + fout << "echo \" ffplay " << fname_inp << ".mp4\"\n"; + fout << "\n"; + + fout.close(); + + fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); + } + + return true; +} + int main(int argc, char ** argv) { whisper_params params; @@ -146,6 +712,11 @@ int main(int argc, char ** argv) { struct whisper_context * ctx = whisper_init(params.model.c_str()); + if (ctx == nullptr) { + fprintf(stderr, "error: failed to initialize whisper context\n"); + return 3; + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; @@ -156,22 +727,22 @@ int main(int argc, char ** argv) { if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); whisper_print_usage(argc, argv, {}); - return 3; + return 4; } if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); - return 4; + return 5; } if (wav.sampleRate != WHISPER_SAMPLE_RATE) { fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); - return 5; + return 6; } if (wav.bitsPerSample != 16) { fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); - return 6; + return 7; } int n = wav.totalPCMFrameCount; @@ -194,6 +765,13 @@ int main(int argc, char ** argv) { } } + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); + } + // print some info about the processing { fprintf(stderr, "\n"); @@ -204,8 +782,9 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", - __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, + params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", params.no_timestamps ? 0 : 1); @@ -218,108 +797,54 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = true; + wparams.print_realtime = false; wparams.print_progress = false; wparams.print_timestamps = !params.no_timestamps; wparams.print_special_tokens = params.print_special_tokens; wparams.translate = params.translate; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; - wparams.offset_ms = params.offset_ms; - - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 7; - } + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; - // print result + // this callback is called on each new segment if (!wparams.print_realtime) { - printf("\n"); - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - - if (params.no_timestamps) { - printf("%s", text); - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = ¶ms; + } - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); - } - } + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + return 8; } + } + // output stuff + { printf("\n"); // output to text file if (params.output_txt) { - const auto fname_txt = fname_inp + ".txt"; - std::ofstream fout_txt(fname_txt); - if (!fout_txt.is_open()) { - fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str()); - return 8; - } - - fprintf(stderr, "%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str()); - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - fout_txt << text; - } + output_txt(ctx, fname_txt.c_str()); } // output to VTT file if (params.output_vtt) { - const auto fname_vtt = fname_inp + ".vtt"; - std::ofstream fout_vtt(fname_vtt); - if (!fout_vtt.is_open()) { - fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str()); - return 9; - } - - fprintf(stderr, "%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str()); - - fout_vtt << "WEBVTT\n\n"; - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - - fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - fout_vtt << text << "\n\n"; - } + output_vtt(ctx, fname_vtt.c_str()); } // output to SRT file if (params.output_srt) { - const auto fname_srt = fname_inp + ".srt"; - std::ofstream fout_srt(fname_srt); - if (!fout_srt.is_open()) { - fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_srt.c_str()); - return 10; - } - - fprintf(stderr, "%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str()); - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + output_srt(ctx, fname_srt.c_str(), params); + } - fout_srt << i + 1 << "\n"; - fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - fout_srt << text << "\n\n"; - } + // output to WTS file + if (params.output_wts) { + const auto fname_wts = fname_inp + ".wts"; + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32); } } } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 2d2b8ce..7f2b49b 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -1,3 +1,4 @@ +#define WHISPER_BUILD #include "whisper.h" #include "ggml.h" @@ -210,16 +211,13 @@ struct whisper_vocab { } }; -struct whisper_result { - int64_t t; - whisper_token id; -}; - struct whisper_segment { int64_t t0; int64_t t1; std::string text; + + std::vector tokens; }; // medium @@ -379,8 +377,12 @@ struct whisper_model { struct ggml_tensor * memory_cross_k; struct ggml_tensor * memory_cross_v; - // + // context struct ggml_context * ctx; + struct ggml_context * ctx_mem; + + // tensors + int n_loaded; std::map tensors; }; @@ -392,9 +394,10 @@ struct whisper_context { int64_t t_decode_us = 0; int64_t t_start_us = 0; - std::vector buf_model; - std::vector buf_compute; - std::vector buf_compute_layer; + std::vector * buf_model; // the model buffer is read-only and can be shared between processors + std::vector buf_memory; + std::vector buf_compute; + std::vector buf_compute_layer; whisper_model model; whisper_vocab vocab; @@ -404,7 +407,6 @@ struct whisper_context { std::vector probs; std::vector logits; - std::vector result_cur; std::vector result_all; std::vector prompt_past; @@ -494,13 +496,16 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: type = %d\n", __func__, model.type); - wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type)); + wctx.buf_model = new std::vector(); + wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); + wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!! wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); // this is the total memory required to run the inference const size_t mem_required = - wctx.buf_model.size() + + wctx.buf_model->size() + + wctx.buf_memory.size() + wctx.buf_compute.size() + wctx.buf_compute_layer.size(); @@ -583,6 +588,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { size_t ctx_size = 0; + size_t ctx_mem_size = 0; { const auto & hparams = model.hparams; @@ -691,11 +697,11 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b } - ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k - ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v + ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k + ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v - ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k - ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v + ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k + ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead @@ -705,8 +711,8 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // create the ggml context { struct ggml_init_params params = { - .mem_size = wctx.buf_model.size(), - .mem_buffer = wctx.buf_model.data(), + .mem_size = wctx.buf_model->size(), + .mem_buffer = wctx.buf_model->data(), }; model.ctx = ggml_init(params); @@ -716,6 +722,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { } } + // create the ggml memory context + { + struct ggml_init_params params = { + .mem_size = wctx.buf_memory.size(), + .mem_buffer = wctx.buf_memory.data(), + }; + + model.ctx_mem = ggml_init(params); + if (!model.ctx_mem) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + // prepare memory for the weights { auto & ctx = model.ctx; @@ -914,7 +934,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // key + value memory { - auto & ctx = model.ctx; + auto & ctx = model.ctx_mem; const auto & hparams = model.hparams; @@ -946,14 +966,15 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); - fprintf(stderr, "%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0); + fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0); } // load weights { - int n_loaded = 0; size_t total_size = 0; + model.n_loaded = 0; + while (true) { int32_t n_dims; int32_t length; @@ -1006,15 +1027,15 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); - n_loaded++; + model.n_loaded++; } fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); - if (n_loaded == 0) { + if (model.n_loaded == 0) { fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); - } else if (n_loaded != (int) model.tensors.size()) { - fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), n_loaded); + } else if (model.n_loaded != (int) model.tensors.size()) { + fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return false; } } @@ -1782,9 +1803,11 @@ bool whisper_decode( } // the most basic sampling scheme - select the top token -whisper_vocab::id whisper_sample_best( +whisper_token_data whisper_sample_best( const whisper_vocab & vocab, const float * probs) { + whisper_token_data result; + int n_logits = vocab.id_to_token.size(); std::vector> probs_id; @@ -1794,24 +1817,34 @@ whisper_vocab::id whisper_sample_best( probs_id.push_back(std::make_pair(probs[i], i)); } - double sum_ts = 0.0; - double max_tx = 0.0; + { + double sum_ts = 0.0; + double max_ts = -1.0; + double max_tx = -1.0; - for (int i = 0; i < vocab.token_beg; i++) { - max_tx = std::max(max_tx, probs_id[i].first); - } + for (int i = 0; i < vocab.token_beg; i++) { + max_tx = std::max(max_tx, probs_id[i].first); + } - for (int i = vocab.token_beg; i < n_logits; i++) { - sum_ts += probs_id[i].first; - } + for (int i = vocab.token_beg; i < n_logits; i++) { + sum_ts += probs_id[i].first; + if (probs_id[i].first > max_ts) { + max_ts = probs_id[i].first; + result.tid = probs_id[i].second; + } + } - // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a - // timestamp token - if (sum_ts > max_tx) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 - for (int i = 0; i < vocab.token_beg; i++) { - probs_id[i].first = -INFINITY; + // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a + // timestamp token + if (sum_ts > max_tx) { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + for (int i = 0; i < vocab.token_beg; i++) { + probs_id[i].first = -INFINITY; + } } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; } // find the top K tokens @@ -1839,7 +1872,10 @@ whisper_vocab::id whisper_sample_best( res++; } - return probs_id[res].second; + result.id = probs_id[res].second; + result.p = probs_id[res].first; + + return result; } // samples only from the timestamps tokens @@ -1875,14 +1911,19 @@ whisper_vocab::id whisper_sample_timestamp( return probs_id[0].second; } -static std::string to_timestamp(int64_t t) { - int64_t sec = t/100; - int64_t msec = t - sec*100; - int64_t min = sec/60; - sec = sec - min*60; +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec); + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); return std::string(buf); } @@ -2104,6 +2145,9 @@ struct whisper_context * whisper_init(const char * path_model) { void whisper_free(struct whisper_context * ctx) { if (ctx) { + if (ctx->buf_model) { + delete ctx->buf_model; + } delete ctx; } } @@ -2166,7 +2210,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return 0; } -whisper_token whisper_sample_best(struct whisper_context * ctx) { +struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); // TODO: simplify @@ -2277,6 +2321,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.strategy =*/ WHISPER_SAMPLING_GREEDY, /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, /*.offset_ms =*/ 0, /*.translate =*/ false, @@ -2297,6 +2342,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_width =*/ -1, /*.n_best =*/ -1, }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: @@ -2305,6 +2353,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, /*.offset_ms =*/ 0, /*.translate =*/ false, @@ -2325,6 +2374,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_width =*/ 10, /*.n_best =*/ 5, }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, }; } break; } @@ -2339,7 +2391,6 @@ int whisper_full( int n_samples) { // clear old results auto & result_all = ctx->result_all; - auto & result_cur = ctx->result_cur; result_all.clear(); @@ -2349,10 +2400,12 @@ int whisper_full( return -1; } + const int seek_start = params.offset_ms/10; + // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 - if (whisper_n_len(ctx) < 100) { + if (whisper_n_len(ctx) < 100 + seek_start) { return 0; } @@ -2376,8 +2429,14 @@ int whisper_full( int progress_prev = 0; int progress_step = 5; + std::vector tokens_cur; + tokens_cur.reserve(whisper_n_text_ctx(ctx)); + + std::vector prompt; + prompt.reserve(whisper_n_text_ctx(ctx)); + // main loop - int seek = params.offset_ms/10; + int seek = seek_start; while (true) { int progress_cur = (100*seek)/whisper_n_len(ctx); while (progress_cur >= progress_prev + progress_step) { @@ -2397,13 +2456,12 @@ int whisper_full( return 7; } - std::vector prompt; - int n_past = 0; + prompt.clear(); // if we have already generated some text, use it as a prompt to condition the next generation if (prompt_past.size() > 0) { - int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size())); + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); prompt = { whisper_token_prev(ctx) }; prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); @@ -2426,7 +2484,7 @@ int whisper_full( // the accumulated transcription in the current interation int result_len = 0; - result_cur.clear(); + tokens_cur.clear(); for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { @@ -2445,28 +2503,29 @@ int whisper_full( // feel free to experiment! // { - whisper_token id = 0; - whisper_token tid = whisper_token_beg(ctx); + auto token = whisper_sample_best(ctx); - id = whisper_sample_best(ctx); - if (i > 0) { - tid = whisper_sample_timestamp(ctx); + if (i == 0) { + token.tid = whisper_token_beg(ctx); } - // update sliding window - if (id > whisper_token_beg(ctx)) { - seek_delta = 2*(id - whisper_token_beg(ctx)); + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + seek_delta = 2*(token.id - whisper_token_beg(ctx)); result_len = i + 1; } // add it to the context - prompt.push_back(id); - result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id }); + prompt.push_back(token.id); + tokens_cur.push_back(token); - //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str()); + //{ + // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; + // printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str()); + //} // end of text token - if (id == whisper_token_eot(ctx)) { + if (token.id == whisper_token_eot(ctx)) { if (result_len == 0) { if (seek + seek_delta + 100 >= whisper_n_len(ctx)) { result_len = i + 1; @@ -2477,6 +2536,12 @@ int whisper_full( } break; } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + break; + } } if (done) { @@ -2484,25 +2549,30 @@ int whisper_full( } } - result_cur.resize(result_len); + tokens_cur.resize(result_len); - for (const auto & r : result_cur) { + for (const auto & r : tokens_cur) { prompt_past.push_back(r.id); } // store the text from this iteration - if (result_cur.size() > 0) { - auto t0 = result_cur.front().t; + if (tokens_cur.size() > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text = ""; - for (int i = 0; i < (int) result_cur.size(); i++) { - if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) { + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { } else { - text += whisper_token_to_str(ctx, result_cur[i].id); + text += whisper_token_to_str(ctx, tokens_cur[i].id); } - if (result_cur[i].id > whisper_token_beg(ctx)) { - const auto t1 = result_cur[i].t; + if (tokens_cur[i].id > whisper_token_beg(ctx)) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { if (params.print_realtime) { if (params.print_timestamps) { @@ -2513,14 +2583,21 @@ int whisper_full( } } - result_all.push_back({ t0, t1, text }); + result_all.push_back({ t0, t1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, params.new_segment_callback_user_data); + } } text = ""; - while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) { + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { i++; } i--; - t0 = result_cur[i].t; + t0 = t1; + i0 = i + 1; } } @@ -2536,7 +2613,13 @@ int whisper_full( } } - result_all.push_back({ t0, t1, text }); + result_all.push_back({ t0, t1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, params.new_segment_callback_user_data); + } } } @@ -2546,6 +2629,156 @@ int whisper_full( return 0; } +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + const int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } + + int ret = 0; + + // prepare separate contexts for each thread + std::vector ctxs(n_processors - 1); + + for (int i = 0; i < n_processors - 1; ++i) { + ctxs[i] = *ctx; + + auto & model = ctxs[i].model; + + // create the ggml memory context + { + struct ggml_init_params params = { + .mem_size = ctxs[i].buf_memory.size(), + .mem_buffer = ctxs[i].buf_memory.data(), + }; + + model.ctx_mem = ggml_init(params); + if (!model.ctx_mem) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // separate key + value memory for each processor + { + auto & ctx = model.ctx_mem; + + const auto & hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; + + // key/value memory for the self-attention layer + { + const int n_mem = n_text_layer*n_text_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + + // key/value memory for the cross-attention layer + { + const int n_audio_ctx = hparams.n_audio_ctx; + + const int n_mem = n_text_layer*n_audio_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + + const size_t memory_size = + ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + + ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); + } + } + + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into ctx->result_all + for (int i = 0; i < n_processors - 1; ++i) { + auto & results_i = ctxs[i].result_all; + + for (int j = 0; j < (int) results_i.size(); ++j) { + // correct the segment timestamp taking into account the offset + results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (ctx->result_all.size() > 0) { + results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1); + } + + ctx->result_all.push_back(std::move(results_i[j])); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, params.new_segment_callback_user_data); + } + } + + ctx->t_mel_us += ctxs[i].t_mel_us; + ctx->t_sample_us += ctxs[i].t_sample_us; + ctx->t_encode_us += ctxs[i].t_encode_us; + ctx->t_decode_us += ctxs[i].t_decode_us; + } + + // average the timings + ctx->t_mel_us /= n_processors; + ctx->t_sample_us /= n_processors; + ctx->t_encode_us /= n_processors; + ctx->t_decode_us /= n_processors; + + // print information about the audio boundaries + fprintf(stderr, "\n"); + fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) { + fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + } + fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + int whisper_full_n_segments(struct whisper_context * ctx) { return ctx->result_all.size(); } @@ -2561,3 +2794,37 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { return ctx->result_all[i_segment].text.c_str(); } + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].p; +} + +const char * whisper_print_system_info() { + static std::string s; + + s = ""; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + + return s.c_str(); +} diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 4423674..5d7c40d 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -68,6 +68,15 @@ extern "C" { typedef int whisper_token; + struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + }; + // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. WHISPER_API struct whisper_context * whisper_init(const char * path_model); @@ -120,7 +129,7 @@ extern "C" { // You can also implement your own sampling method using the whisper_get_probs() function. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx); + WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); // Return the id of the specified language, returns -1 if not found @@ -160,10 +169,16 @@ extern "C" { WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! }; + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); + struct whisper_full_params { enum whisper_sampling_strategy strategy; int n_threads; + int n_max_text_ctx; int offset_ms; bool translate; @@ -184,6 +199,9 @@ extern "C" { int beam_width; int n_best; } beam_search; + + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); @@ -196,6 +214,16 @@ extern "C" { const float * samples, int n_samples); + // Split the input audio in chunks and process each chunk separately using whisper_full() + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + const int n_processors); + // Number of generated text segments. // A segment can be a few words, a sentence, or even a paragraph. WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); @@ -207,6 +235,23 @@ extern "C" { // Get the text of the specified segment. WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + // Get number of tokens in the specified segment. + WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + + // Get the token text of the specified token in the specified segment. + WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment. + // This contains probabilities, timestamps, etc. + WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment. + WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + + // Print system information + WHISPER_API const char * whisper_print_system_info(); + #ifdef __cplusplus } #endif diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 34f104b..f92ae73 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -11,7 +11,7 @@ extern "C" { #define GGML_MAX_DIMS 4 #define GGML_MAX_NODES 4096 #define GGML_MAX_PARAMS 16 -#define GGML_MAX_CONTEXTS 16 +#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_OPT 4 #ifdef __ARM_NEON @@ -548,6 +548,17 @@ enum ggml_opt_result ggml_opt( struct ggml_opt_params params, struct ggml_tensor * f); +// +// system info +// + +int ggml_cpu_has_avx2(void); +int ggml_cpu_has_avx512(void); +int ggml_cpu_has_neon(void); +int ggml_cpu_has_fp16_va(void); +int ggml_cpu_has_wasm_simd(void); +int ggml_cpu_has_blas(void); + #ifdef __cplusplus } #endif diff --git a/src/ggml.c b/src/ggml.c index 115e619..1000a5b 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -15,10 +15,44 @@ #include #if defined _MSC_VER -#include "msvc_thread_atomic.h" +#include + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int* ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int* ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { + out = CreateThread(NULL, 0, func, arg, 0, NULL); + return out != NULL; +} + +static int pthread_join(pthread_t thread, void* unused) { + return (int) WaitForSingleObject(thread, INFINITE); +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} #else #include #include + typedef void* thread_ret_t; #endif @@ -47,6 +81,8 @@ typedef void* thread_ret_t; #ifdef GGML_USE_ACCELERATE #include +#elif GGML_USE_OPENBLAS +#include #endif // floating point type used to accumulate sums @@ -73,7 +109,11 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { #else +#ifdef __wasm_simd128__ +#include +#else #include +#endif // FP16 <-> FP32 // ref: https://github.com/Maratyszcza/FP16 @@ -288,7 +328,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float sumf += x[i]*y[i]; } #elif defined(__AVX2__) - // AVX 256-bit (unroll 4) + // AVX 256-bit const int n32 = (n & ~31); __m256 sum0 = _mm256_setzero_ps(); @@ -330,6 +370,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float for (int i = n32; i < n; ++i) { sumf += x[i]*y[i]; } +#elif defined(__wasm_simd128__) + // WASM 128-bit + const int n16 = (n & ~15); + + v128_t sum0 = wasm_f32x4_splat(0); + v128_t sum1 = wasm_f32x4_splat(0); + v128_t sum2 = wasm_f32x4_splat(0); + v128_t sum3 = wasm_f32x4_splat(0); + + v128_t x0, x1, x2, x3; + v128_t y0, y1, y2, y3; + + for (int i = 0; i < n16; i += 16) { + x0 = wasm_v128_load(x + i + 0); + x1 = wasm_v128_load(x + i + 4); + x2 = wasm_v128_load(x + i + 8); + x3 = wasm_v128_load(x + i + 12); + + y0 = wasm_v128_load(y + i + 0); + y1 = wasm_v128_load(y + i + 4); + y2 = wasm_v128_load(y + i + 8); + y3 = wasm_v128_load(y + i + 12); + + sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); + sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); + sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); + sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); + } + + sum0 = wasm_f32x4_add(sum0, sum1); + sum2 = wasm_f32x4_add(sum2, sum3); + sum0 = wasm_f32x4_add(sum0, sum2); + + sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); + + // leftovers + for (int i = n16; i < n; ++i) { + sumf += x[i]*y[i]; + } #else // scalar for (int i = 0; i < n; ++i) { @@ -446,7 +525,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } #elif defined(__AVX2__) - // AVX 256-bit (unroll 4) + // AVX 256-bit const int n32 = (n & ~31); __m256 sum0 = _mm256_setzero_ps(); @@ -489,6 +568,54 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t //GGML_ASSERT(false); sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } +#elif defined(__wasm_simd128__) + // WASM 128-bit + const int n16 = (n & ~15); + + v128_t sum0 = wasm_f32x4_splat(0.0f); + v128_t sum1 = wasm_f32x4_splat(0.0f); + v128_t sum2 = wasm_f32x4_splat(0.0f); + v128_t sum3 = wasm_f32x4_splat(0.0f); + + v128_t x0, x1, x2, x3; + v128_t y0, y1, y2, y3; + + float tx[16]; + float ty[16]; + + for (int i = 0; i < n16; i += 16) { + for (int k = 0; k < 16; ++k) { + tx[k] = ggml_fp16_to_fp32(x[i + k]); + ty[k] = ggml_fp16_to_fp32(y[i + k]); + } + + x0 = wasm_v128_load(tx + 0); + x1 = wasm_v128_load(tx + 4); + x2 = wasm_v128_load(tx + 8); + x3 = wasm_v128_load(tx + 12); + + y0 = wasm_v128_load(ty + 0); + y1 = wasm_v128_load(ty + 4); + y2 = wasm_v128_load(ty + 8); + y3 = wasm_v128_load(ty + 12); + + sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); + sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); + sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); + sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); + } + + sum0 = wasm_f32x4_add(sum0, sum1); + sum2 = wasm_f32x4_add(sum2, sum3); + sum0 = wasm_f32x4_add(sum0, sum2); + + sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); + + // leftovers + for (int i = n16; i < n; ++i) { + //GGML_ASSERT(false); + sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + } #else for (int i = 0; i < n; ++i) { sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); @@ -535,7 +662,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y[i] += x[i]*v; } #elif defined(__AVX2__) - // AVX 256-bit (unroll 4) + // AVX 256-bit const int n32 = (n & ~31); const __m256 v4 = _mm256_set1_ps(v); @@ -569,6 +696,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float for (int i = n32; i < n; ++i) { y[i] += x[i]*v; } +#elif defined(__wasm_simd128__) + // WASM SIMD 128-bit + const int n16 = (n & ~15); + + const v128_t v4 = wasm_f32x4_splat(v); + + v128_t x0, x1, x2, x3; + v128_t y0, y1, y2, y3; + + for (int i = 0; i < n16; i += 16) { + x0 = wasm_v128_load(x + i + 0); + x1 = wasm_v128_load(x + i + 4); + x2 = wasm_v128_load(x + i + 8); + x3 = wasm_v128_load(x + i + 12); + + y0 = wasm_v128_load(y + i + 0); + y1 = wasm_v128_load(y + i + 4); + y2 = wasm_v128_load(y + i + 8); + y3 = wasm_v128_load(y + i + 12); + + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + + wasm_v128_store(y + i + 0, y0); + wasm_v128_store(y + i + 4, y1); + wasm_v128_store(y + i + 8, y2); + wasm_v128_store(y + i + 12, y3); + } + + // leftovers + for (int i = n16; i < n; ++i) { + y[i] += x[i]*v; + } #else // scalar for (int i = 0; i < n; ++i) { @@ -696,6 +858,54 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ GGML_ASSERT(false); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); } +#elif defined(__wasm_simd128__) + // WASM SIMD 128-bit + const int n16 = (n & ~15); + + const v128_t v4 = wasm_f32x4_splat(v); + + v128_t x0, x1, x2, x3; + v128_t y0, y1, y2, y3; + + float tx[16]; + float ty[16]; + + for (int i = 0; i < n16; i += 16) { + for (int k = 0; k < 16; ++k) { + tx[k] = ggml_fp16_to_fp32(x[i + k]); + ty[k] = ggml_fp16_to_fp32(y[i + k]); + } + + x0 = wasm_v128_load(tx + 0); + x1 = wasm_v128_load(tx + 4); + x2 = wasm_v128_load(tx + 8); + x3 = wasm_v128_load(tx + 12); + + y0 = wasm_v128_load(ty + 0); + y1 = wasm_v128_load(ty + 4); + y2 = wasm_v128_load(ty + 8); + y3 = wasm_v128_load(ty + 12); + + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + + wasm_v128_store(ty + 0, y0); + wasm_v128_store(ty + 4, y1); + wasm_v128_store(ty + 8, y2); + wasm_v128_store(ty + 12, y3); + + for (int k = 0; k < 16; ++k) { + y[i + k] = ggml_fp32_to_fp16(ty[k]); + } + } + + // leftovers + for (int i = n16; i < n; ++i) { + GGML_ASSERT(false); + y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + } #else for (int i = 0; i < n; ++i) { y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); @@ -931,6 +1141,7 @@ struct ggml_state { // global state struct ggml_state g_state; +atomic_int g_state_barrier = 0; //////////////////////////////////////////////////////////////////////////////// @@ -1060,6 +1271,17 @@ int ggml_up64(int n) { //////////////////////////////////////////////////////////////////////////////// struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + { + int processing = atomic_fetch_add(&g_state_barrier, 1); + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); + processing = atomic_fetch_add(&g_state_barrier, 1); + } + } + static bool is_first_call = true; if (is_first_call) { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); @@ -1103,6 +1325,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { if (ctx == NULL) { GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); + return NULL; } @@ -1117,10 +1342,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_assert_aligned(ctx->mem_buffer); + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); + return ctx; } void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + { + int processing = atomic_fetch_add(&g_state_barrier, 1); + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); + processing = atomic_fetch_add(&g_state_barrier, 1); + } + } + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { if (&g_state.contexts[i].context == ctx) { g_state.contexts[i].used = false; @@ -1132,11 +1372,15 @@ void ggml_free(struct ggml_context * ctx) { free(ctx->mem_buffer); } + atomic_fetch_sub(&g_state_barrier, 1); + return; } } GGML_PRINT_DEBUG("%s: context not found\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); } size_t ggml_used_mem(const struct ggml_context * ctx) { @@ -3852,46 +4096,44 @@ void ggml_compute_forward_mul_mat_f32( // nb00 < nb01 - src0 is transposed // compute by src0 columns -//#ifdef GGML_USE_ACCELERATE -// if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { -// GGML_ASSERT(ggml_is_contiguous(src0)); -// GGML_ASSERT(nb10 == sizeof(float)); -// -// if (params->ith != 0) return; -// -// if (params->type == GGML_TASK_INIT) { -// return; -// } -// -// if (params->type == GGML_TASK_FINALIZE) { -// return; -// } -// -// float * const wdata = params->wdata; -// -// for (int i03 = 0; i03 < ne03; i03++) { -// for (int i02 = 0; i02 < ne02; i02++) { -// const float * x = (float *) (src0->data); -// const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); -// -// float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); -// -// // zT = y * xT -// { -// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, -// ne11, ne01, ne10, -// 1.0f, y, ne10, -// x, ne10, -// 0.0f, d, ne01); -// } -// } -// } -// -// //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); -// -// return; -// } -//#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) return; + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + const float * x = (float *) (src0->data); + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif if (params->type == GGML_TASK_INIT) { if (nb01 >= nb00) { @@ -4098,7 +4340,7 @@ void ggml_compute_forward_mul_mat_f16_f32( // nb00 < nb01 - src0 is transposed // compute by src0 columns -#ifdef GGML_USE_ACCELERATE +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { GGML_ASSERT(nb10 == sizeof(float)); @@ -6654,7 +6896,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else { if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { -#ifdef GGML_USE_ACCELERATE +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); } else { @@ -7358,7 +7600,7 @@ enum ggml_opt_result ggml_opt_adam( { const int64_t t_end_cpu = ggml_cycles(); - GGML_PRINT_DEBUG("time iter: %5.3f s\n", (t_end_cpu - t_start_cpu)/CLOCKS_PER_SEC); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); UNUSED(t_end_cpu); const int64_t t_end_wall = ggml_time_us(); @@ -7829,3 +8071,53 @@ enum ggml_opt_result ggml_opt( } //////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/msvc_thread_atomic.h b/src/msvc_thread_atomic.h deleted file mode 100644 index 52cd419..0000000 --- a/src/msvc_thread_atomic.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once -#include - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; - -static void atomic_store(atomic_int* ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int* ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); -} - -typedef HANDLE pthread_t; - -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { - out = CreateThread(NULL, 0, func, arg, 0, NULL); - return out != NULL; -} - -static int pthread_join(pthread_t thread, void* unused) { - return (int) WaitForSingleObject(thread, INFINITE); -} -