From 90ee5c6358a3f33a5615256a0b229aa793ff4b49 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 9 Nov 2022 21:43:03 +0200 Subject: [PATCH] sync : latest changes from whisper.cpp - Documentation - whisper : token-level timestamps - ggml : Windows build fixes - etc. --- examples/whisper/main.cpp | 546 ++++++++++------------------------- examples/whisper/whisper.cpp | 494 ++++++++++++++++++++++++++++--- examples/whisper/whisper.h | 26 +- include/ggml/ggml.h | 175 ++++++++++- src/ggml.c | 23 +- 5 files changed, 818 insertions(+), 446 deletions(-) diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 1be0032..7058031 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) { return std::string(buf); } +// helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { pos = s.find(search, pos); @@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string } } -// a cost-function that is high for text that takes longer to pronounce -float voice_length(const std::string & text) { - float res = 0.0f; - - for (size_t i = 0; i < text.size(); ++i) { - if (text[i] == ' ') { - res += 0.01f; - } else if (text[i] == ',') { - res += 2.00f; - } else if (text[i] == '.') { - res += 3.00f; - } else if (text[i] == '!') { - res += 3.00f; - } else if (text[i] == '?') { - res += 3.00f; - } else if (text[i] >= '0' && text[i] <= '9') { - res += 3.00f; - } else { - res += 1.00f; - } - } - - return res; -} - // command-line parameters struct whisper_params { int32_t seed = -1; // RNG seed, not used currently @@ -77,7 +53,9 @@ struct whisper_params { int32_t n_processors = 1; int32_t offset_t_ms = 0; int32_t offset_n = 0; + int32_t duration_ms = 0; int32_t max_context = -1; + int32_t max_len = 0; float word_thold = 0.01f; @@ -118,8 +96,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.offset_t_ms = std::stoi(argv[++i]); } else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); + } else if (arg == "-d" || arg == "--duration") { + params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); + } else if (arg == "-ml" || arg == "--max-len") { + params.max_len = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { @@ -175,14 +157,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); + fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms); fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); + fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -v, 
--verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); - fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n"); + fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); @@ -192,65 +176,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } -void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { const whisper_params & params = *(whisper_params *) user_data; const int n_segments = whisper_full_n_segments(ctx); - // print the last segment - const int i = n_segments - 1; - if (i == 0) { + // print the last n_new segments + const int s0 = n_segments - n_new; + if (s0 == 0) { printf("\n"); } - if (params.no_timestamps) { - if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + for (int i = s0; i < n_segments; i++) { + if (params.no_timestamps) { + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } - } - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); } + fflush(stdout); } else { - const char * text = whisper_full_get_segment_text(ctx, i); - printf("%s", text); - } - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - if (params.print_colors) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = 
whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } - } - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); - } - printf("\n"); - } else { - const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } } } } @@ -320,373 +306,117 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ return true; } -// word-level timestamps (experimental) -// TODO: probably still has bugs, needs refactoring, etc.. -// TODO: auto threshold -// TODO: extra pass to detect unused speech and assign to tokens +// karaoke video generation +// outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector & pcmf32) { - if (params.output_wts) { - std::vector pcm_avg(pcmf32.size(), 0); - - // average the fabs of the signal - { - const int hw = 32; - - for (int i = 0; i < pcmf32.size(); i++) { - float sum = 0; - for (int j = -hw; j <= hw; j++) { - if (i + j >= 0 && i + j < pcmf32.size()) { - sum += fabs(pcmf32[i + j]); - } - } - pcm_avg[i] = sum/(2*hw + 1); - } - } - - struct token_info { - int64_t t0 = -1; - int64_t t1 = -1; - - int64_t tt0 = -1; - int64_t tt1 = -1; +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { + std::ofstream fout(fname); - whisper_token id; - whisper_token tid; + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); - float p = 0.0f; - float pt = 0.0f; - float ptsum = 0.0f; + // TODO: become parameter + static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string text; - float vlen = 0.0f; // voice length of this token - }; + fout << "#!/bin/bash" << "\n"; + fout << "\n"; - int64_t t_beg = 0; - int64_t t_last = 0; + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; - whisper_token tid_last = 0; + for (int i = 0; i < whisper_full_n_segments(ctx); i++) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::ofstream fout(fname); + const int n = whisper_full_n_tokens(ctx, i); - fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + std::vector tokens(n); + for (int j = 0; j < n; ++j) { + tokens[j] = whisper_full_get_token_data(ctx, i, j); + } - fout << "!/bin/bash" << "\n"; - fout << "\n"; + if (i > 0) { + fout << ","; + } - fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << 
":rate=25:color=black -vf \""; + // background text + fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; bool is_first = true; - for (int i = 0; i < whisper_full_n_segments(ctx); i++) { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - - const char *text = whisper_full_get_segment_text(ctx, i); - - const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100)); - - const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { + const auto & token = tokens[j]; - std::vector tokens(n); - - if (n <= 1) { + if (tokens[j].id >= whisper_token_eot(ctx)) { continue; } - for (int j = 0; j < n; ++j) { - struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - - if (j == 0) { - if (token.id == whisper_token_beg(ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; - tokens[j + 1].t0 = t0; - - t_beg = t0; - t_last = t0; - tid_last = whisper_token_beg(ctx); - } else { - tokens[j ].t0 = t_last; - } - } - - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); - - tokens[j].id = token.id; - tokens[j].tid = token.tid; - tokens[j].p = token.p; - tokens[j].pt = token.pt; - tokens[j].ptsum = token.ptsum; - - tokens[j].text = whisper_token_to_str(ctx, token.id); - //tokens[j].vlen = tokens[j].pt; - tokens[j].vlen = voice_length(tokens[j].text); - - if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) { - if (j > 0) { - tokens[j - 1].t1 = tt; - } - tokens[j].t0 = tt; - tid_last = token.tid; - } - } - - tokens[n - 2].t1 = t1; - tokens[n - 1].t0 = t1; - tokens[n - 1].t1 = t1; - - t_last = t1; - - int p0 = 0; - int p1 = 0; - while (true) { - while (p1 < n && tokens[p1].t1 < 0) { - p1++; - } - - if (p1 >= n) { - p1--; - } - - if (p1 > p0) { - double psum = 0.0; - for (int j = p0; j <= p1; j++) { - psum += tokens[j].vlen; - } - - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); - - const double dt = tokens[p1].t1 - tokens[p0].t0; + std::string txt_bg; + std::string txt_fg; // highlight token + std::string txt_ul; // underline - for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; - //const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1); - //const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1); + txt_bg = "> "; + txt_fg = "> "; + txt_ul = "\\ \\ "; - tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; - } - } - - p1++; - p0 = p1; - if (p1 >= n) { - break; - } - } - - for (int j = 0; j < n - 1; j++) { - if (tokens[j].t1 < 0) { - tokens[j + 1].t0 = tokens[j].t1; - } - - if (j > 0) { - if (tokens[j - 1].t1 > tokens[j].t0) { - tokens[j].t0 = tokens[j - 1].t1; - tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); - } - } - - tokens[j].tt0 = tokens[j].t0; - tokens[j].tt1 = tokens[j].t1; - } - - // VAD { - const int hw = WHISPER_SAMPLE_RATE/8; + int ncnt = 0; + for (int k = 0; k < n; ++k) { + const auto & token2 = tokens[k]; - for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(ctx)) { + if (tokens[k].id >= whisper_token_eot(ctx)) { continue; } - const int64_t t0 = tokens[j].t0; - const int64_t t1 = tokens[j].t1; - - int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100)); + const std::string txt = 
whisper_token_to_str(ctx, token2.id); - const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw); - const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw); - - const int n = ss1 - ss0; - - float sum = 0.0f; - - for (int k = ss0; k < ss1; k++) { - sum += pcm_avg[k]; - } + txt_bg += txt; - const float thold = 0.5*sum/n; - - { - int k = s0; - if (pcm_avg[k] > thold && j > 0) { - while (k > 0 && pcm_avg[k] > thold) { - k--; - } - tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE); - if (tokens[j].t0 < tokens[j - 1].t1) { - tokens[j].t0 = tokens[j - 1].t1; - } else { - s0 = k; - } - } else { - while (pcm_avg[k] < thold && k < s1) { - k++; - } - s0 = k; - tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE; + if (k == j) { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += txt[l]; + txt_ul += "_"; } - } - - { - int k = s1; - if (pcm_avg[k] > thold) { - while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) { - k++; - } - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; - if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { - tokens[j].t1 = tokens[j + 1].t0; - } else { - s1 = k; - } - } else { - while (pcm_avg[k] < thold && k > s0) { - k--; - } - s1 = k; - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; + txt_fg += "|"; + } else { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += "\\ "; + txt_ul += "\\ "; } } - } - } - - const int t_expand = 0; - for (int j = 0; j < n; j++) { - if (j > 0) { - tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + ncnt += txt.size(); } - if (j < n - 1) { - tokens[j].t1 = tokens[j].t1 + t_expand; - } - } - - for (int j = 0; j < n; ++j) { - const auto & token = tokens[j]; - const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; - printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str()); - - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id)); - //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n"; + ::replace_all(txt_bg, "'", "’"); + ::replace_all(txt_bg, "\"", "\\\""); + ::replace_all(txt_fg, "'", "’"); + ::replace_all(txt_fg, "\"", "\\\""); } - static const int line_wrap = 60; - static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - - if (!is_first) { - fout << ","; - } - - // background text - fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; - - is_first = false; - - for (int j = 0; j < n; ++j) { - const auto & token = tokens[j]; - - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - std::string txt_bg; - std::string txt_fg; // highlight token - std::string txt_ul; // underline - - txt_bg = "> "; - txt_fg = "> "; - txt_ul = "\\ \\ "; - - { - int ncnt = 0; - for (int k = 0; k < n; ++k) { - const auto & token2 = tokens[k]; - - if (tokens[k].id >= whisper_token_eot(ctx)) { - continue; - } - - const std::string txt = whisper_token_to_str(ctx, token2.id); - - txt_bg += txt; - - if (k == j) { - for (int l = 0; l < (int) txt.size(); ++l) { - txt_fg += txt[l]; - txt_ul += "_"; - } - txt_fg += "|"; - } else { - for (int l = 0; l < (int) txt.size(); ++l) { - txt_fg += "\\ "; - 
txt_ul += "\\ "; - } - } - - ncnt += txt.size(); - - if (ncnt > line_wrap) { - if (k < j) { - txt_bg = "> "; - txt_fg = "> "; - txt_ul = "\\ \\ "; - ncnt = 0; - } else { - break; - } - } - } - - ::replace_all(txt_bg, "'", "’"); - ::replace_all(txt_bg, "\"", "\\\""); - ::replace_all(txt_fg, "'", "’"); - ::replace_all(txt_fg, "\"", "\\\""); - } - + if (is_first) { // background text - fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'"; + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'"; + is_first = false; + } - // foreground text - fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + // foreground text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; - // underline - fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; - } + // underline + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; } + } - fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n"; + fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n"; - fout << "\n\n"; - fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n"; - fout << "\n"; - fout << "echo \" ffplay " << fname_inp << ".mp4\"\n"; - fout << "\n"; + fout << "\n\n"; + fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n"; + fout << "\n"; + fout << "echo \" ffplay " << fname_inp << ".mp4\"\n"; + fout << "\n"; - fout.close(); + fout.close(); - fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); - } + fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); return true; } @@ -724,9 +454,30 @@ int main(int argc, char ** argv) { std::vector pcmf32; { drwav wav; - if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { - fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); - whisper_print_usage(argc, argv, {}); + + if (fname_inp == "-") { + std::vector wav_data; + { + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) + { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) + { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return 4; + } + } + else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return 4; } @@ -806,6 +557,11 @@ int main(int argc, char ** argv) { wparams.n_threads = params.n_threads; wparams.n_max_text_ctx = params.max_context >= 0 ? 
params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; // this callback is called on each new segment if (!wparams.print_realtime) { @@ -844,7 +600,7 @@ int main(int argc, char ** argv) { // output to WTS file if (params.output_wts) { const auto fname_wts = fname_inp + ".wts"; - output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32); + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); } } } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 7f2b49b..7078863 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -133,11 +133,19 @@ static const std::map> g_lang = { static const size_t MB = 1024*1024; static const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 86ull*MB }, - { MODEL_BASE, 165ull*MB }, - { MODEL_SMALL, 540ull*MB }, - { MODEL_MEDIUM, 1650ull*MB }, - { MODEL_LARGE, 3260ull*MB }, + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, +}; + +static const std::map MEM_REQ_MEMORY = { + { MODEL_TINY, 12ull*MB }, + { MODEL_BASE, 24ull*MB }, + { MODEL_SMALL, 70ull*MB }, + { MODEL_MEDIUM, 184ull*MB }, + { MODEL_LARGE, 306ull*MB }, }; static const std::map MEM_REQ_ENCODE = { @@ -410,6 +418,12 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg; + int64_t t_last; + whisper_token tid_last; + std::vector energy; // PCM signal energy }; // load the model from a ggml file @@ -423,7 +437,7 @@ struct whisper_context { // // see the convert-pt-to-ggml.py script for details // -bool whisper_model_load(const std::string & fname, whisper_context & wctx) { +static bool whisper_model_load(const std::string & fname, whisper_context & wctx) { fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str()); auto & model = wctx.model; @@ -498,7 +512,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { wctx.buf_model = new std::vector(); wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); - wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!! 
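For reference, the token-level timestamp parameters wired up in main.cpp above can be enabled by any client of the library. Below is a minimal sketch, not taken from the patch: the model path and the pcmf32 buffer are placeholders, and the audio is assumed to be 16 kHz mono float samples, as everywhere else in whisper.cpp.

#include "whisper.h"

#include <cstdio>
#include <vector>

// sketch: transcribe a buffer of 16 kHz mono samples with token-level timestamps enabled
static void transcribe_with_token_timestamps(const std::vector<float> & pcmf32) {
    struct whisper_context * ctx = whisper_init("models/ggml-base.en.bin"); // placeholder path
    if (ctx == nullptr) {
        return;
    }

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.token_timestamps = true;  // [EXPERIMENTAL] compute per-token t0/t1
    wparams.thold_pt         = 0.01f; // timestamp token probability threshold
    wparams.max_len          = 60;    // wrap segments to ~60 characters

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) == 0) {
        for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
                if (whisper_full_get_token_id(ctx, i, j) >= whisper_token_eot(ctx)) {
                    continue; // skip special tokens
                }

                const whisper_token_data token = whisper_full_get_token_data(ctx, i, j);

                // t0/t1 are in the same 10 ms units as the segment timestamps
                printf("[%6d --> %6d] %s\n", (int) token.t0, (int) token.t1,
                        whisper_full_get_token_text(ctx, i, j));
            }
        }
    }

    whisper_free(ctx);
}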
+ wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); @@ -722,20 +736,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { } } - // create the ggml memory context - { - struct ggml_init_params params = { - .mem_size = wctx.buf_memory.size(), - .mem_buffer = wctx.buf_memory.data(), - }; - - model.ctx_mem = ggml_init(params); - if (!model.ctx_mem) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - } - // prepare memory for the weights { auto & ctx = model.ctx; @@ -932,6 +932,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { } } + // create the ggml memory context + { + struct ggml_init_params params = { + .mem_size = wctx.buf_memory.size(), + .mem_buffer = wctx.buf_memory.data(), + }; + + model.ctx_mem = ggml_init(params); + if (!model.ctx_mem) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + // key + value memory { auto & ctx = model.ctx_mem; @@ -1054,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // -bool whisper_encode( +static bool whisper_encode( whisper_context & wctx, const int n_threads, const int mel_offset) { @@ -1440,7 +1454,7 @@ bool whisper_encode( // - n_tokens: number of tokens in the prompt // - n_past: number of past tokens to prefix the prompt with // -bool whisper_decode( +static bool whisper_decode( whisper_context & wctx, const int n_threads, const whisper_token * tokens, @@ -1803,10 +1817,12 @@ bool whisper_decode( } // the most basic sampling scheme - select the top token -whisper_token_data whisper_sample_best( +static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, const float * probs) { - whisper_token_data result; + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; int n_logits = vocab.id_to_token.size(); @@ -1879,7 +1895,7 @@ whisper_token_data whisper_sample_best( } // samples only from the timestamps tokens -whisper_vocab::id whisper_sample_timestamp( +static whisper_vocab::id whisper_sample_timestamp( const whisper_vocab & vocab, const float * probs) { int n_logits = vocab.id_to_token.size(); @@ -1931,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) { // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -void dft(const std::vector & in, std::vector & out) { +static void dft(const std::vector & in, std::vector & out) { int N = in.size(); out.resize(N*2); @@ -1955,7 +1971,7 @@ void dft(const std::vector & in, std::vector & out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -void fft(const std::vector & in, std::vector & out) { +static void fft(const std::vector & in, std::vector & out) { out.resize(in.size()*2); int N = in.size(); @@ -2006,7 +2022,7 @@ void fft(const std::vector & in, std::vector & out) { } // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 -bool log_mel_spectrogram( +static bool log_mel_spectrogram( const float * samples, const int n_samples, const int sample_rate, @@ -2323,6 +2339,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str 
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_max_text_ctx =*/ 16384, /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, /*.translate =*/ false, /*.no_context =*/ false, @@ -2331,6 +2348,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2355,6 +2377,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_max_text_ctx =*/ 16384, /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, /*.translate =*/ false, /*.no_context =*/ false, @@ -2363,6 +2386,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2384,6 +2412,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str return result; } +// forward declarations +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { + auto segment = ctx->result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(ctx, token.id); + + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0) { + // split here + ctx->result_all.back().text = std::move(text); + ctx->result_all.back().t1 = token.t0; + ctx->result_all.back().tokens.resize(i); + + ctx->result_all.push_back({}); + ctx->result_all.back().t0 = token.t0; + ctx->result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx->result_all.back().tokens.insert( + ctx->result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx->result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + ctx->result_all.back().text = std::move(text); + + return res; +} + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -2400,12 +2490,20 @@ int whisper_full( return -1; } + if (params.token_timestamps) { + ctx->t_beg = 0; + ctx->t_last = 0; + ctx->tid_last = 0; + ctx->energy = get_signal_energy(samples, n_samples, 32); + } + const int seek_start = params.offset_ms/10; + const int seek_end = seek_start + (params.duration_ms == 0 ? 
whisper_n_len(ctx) : params.duration_ms/10); // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 - if (whisper_n_len(ctx) < 100 + seek_start) { + if (seek_end < 100 + seek_start) { return 0; } @@ -2438,7 +2536,7 @@ int whisper_full( // main loop int seek = seek_start; while (true) { - int progress_cur = (100*seek)/whisper_n_len(ctx); + const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); while (progress_cur >= progress_prev + progress_step) { progress_prev += progress_step; if (params.print_progress) { @@ -2446,7 +2544,7 @@ int whisper_full( } } - if (seek + 100 >= whisper_n_len(ctx)) { + if (seek + 100 >= seek_end) { break; } @@ -2527,7 +2625,7 @@ int whisper_full( // end of text token if (token.id == whisper_token_eot(ctx)) { if (result_len == 0) { - if (seek + seek_delta + 100 >= whisper_n_len(ctx)) { + if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; } else { // TODO: figure out how to resolve this @@ -2549,6 +2647,7 @@ int whisper_full( } } + // shrink down to result_len tokens_cur.resize(result_len); for (const auto & r : tokens_cur) { @@ -2587,8 +2686,19 @@ int whisper_full( for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -2617,8 +2727,19 @@ int whisper_full( for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } } @@ -2752,7 +2873,7 @@ int whisper_full_parallel( // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); } } @@ -2828,3 +2949,304 @@ const char * whisper_print_system_info() { return s.c_str(); } + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. 
Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (size_t i = 0; i < text.size(); ++i) { + if (text[i] == ' ') { + res += 0.01f; + } else if (text[i] == ',') { + res += 2.00f; + } else if (text[i] == '.') { + res += 3.00f; + } else if (text[i] == '!') { + res += 3.00f; + } else if (text[i] == '?') { + res += 3.00f; + } else if (text[i] >= '0' && text[i] <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector result(n_samples); + + for (int i = 0; i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx->result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = ctx->energy.size(); + + if (n_samples == 0) { + fprintf(stderr, "%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int s0 = timestamp_to_sample(t0, n_samples); + const int s1 = timestamp_to_sample(t1, n_samples); + + const int n = tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = ctx->t_beg; + auto & t_last = ctx->t_last; + auto & tid_last = ctx->tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; 
j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx->energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx->energy[k] > thold && j > 0) { + while (k > 0 && ctx->energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx->energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx->energy[k] > thold) { + while (k < n_samples - 1 && ctx->energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx->energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(ctx)) { + // continue; + // } + //} +} diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 5d7c40d..4c112f4 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -68,14 +68,21 @@ extern "C" { typedef int whisper_token; - struct whisper_token_data { + typedef struct whisper_token_data { whisper_token id; // token id whisper_token tid; // forced timestamp token id float p; // probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens - }; + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token + } whisper_token_data; // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. @@ -129,7 +136,7 @@ extern "C" { // You can also implement your own sampling method using the whisper_get_probs() function. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); // Return the id of the specified language, returns -1 if not found @@ -172,14 +179,15 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); struct whisper_full_params { enum whisper_sampling_strategy strategy; int n_threads; int n_max_text_ctx; - int offset_ms; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms bool translate; bool no_context; @@ -188,6 +196,12 @@ extern "C" { bool print_realtime; bool print_timestamps; + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + const char * language; struct { @@ -244,7 +258,7 @@ extern "C" { // Get token data for the specified token in the specified segment. // This contains probabilities, timestamps, etc. - WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); // Get the probability of the specified token in the specified segment. 
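Note on the updated callback typedef above: the callback now receives n_new, the number of segments appended since the previous invocation (more than one when max_len wrapping splits a segment). A hedged sketch of a conforming callback, assuming it is registered through the existing new_segment_callback / new_segment_callback_user_data fields of whisper_full_params:

#include "whisper.h"

#include <cstdio>

// sketch: print only the segments produced by the most recent decoder step
static void on_new_segment(struct whisper_context * ctx, int n_new, void * /*user_data*/) {
    const int n_segments = whisper_full_n_segments(ctx);

    // the last n_new segments are the freshly generated ones
    for (int i = n_segments - n_new; i < n_segments; ++i) {
        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);

        printf("[%6d --> %6d] %s\n", (int) t0, (int) t1, whisper_full_get_segment_text(ctx, i));
    }
}

// registration (wparams obtained from whisper_full_default_params()):
//   wparams.new_segment_callback           = on_new_segment;
//   wparams.new_segment_callback_user_data = nullptr;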
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index f92ae73..f352e71 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -1,5 +1,174 @@ #pragma once +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. +// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph gf = ggml_build_forward(f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute(ctx0, &gf); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. 
All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3); +// +// // a[1, 2] = 1.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; +// +// // a[2, 0] = 2.0f; +// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f; +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. 
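A short, self-contained sketch (not part of the header) that ties the raw stride access shown above to the 1d helpers: it writes a[1, 2] through nb[] and reads the same element back by its flat row-major index.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params;
    params.mem_size   = 16*1024*1024;
    params.mem_buffer = NULL; // let ggml allocate the buffer internally

    struct ggml_context * ctx = ggml_init(params);

    // same 2x3 tensor as in the example above: ne[0] = 2, ne[1] = 3
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);

    // a[1, 2] = 1.0f via the raw strides ...
    *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;

    // ... read back via the flat 1d helper: flat index = 2*ne[0] + 1 = 5
    const float v = ggml_get_f32_1d(a, 5);

    ggml_free(ctx);

    return v == 1.0f ? 0 : 1;
}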
+// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + #ifdef __cplusplus extern "C" { #endif @@ -21,7 +190,8 @@ typedef __fp16 ggml_fp16_t; typedef uint16_t ggml_fp16_t; #endif -float ggml_fp16_to_fp32(ggml_fp16_t x); +// convert FP16 <-> FP32 +float ggml_fp16_to_fp32(ggml_fp16_t x); ggml_fp16_t ggml_fp32_to_fp16(float x); struct ggml_object; @@ -36,6 +206,7 @@ enum ggml_type { GGML_TYPE_COUNT, }; +// available tensor operations: enum ggml_op { GGML_OP_NONE = 0, @@ -136,7 +307,7 @@ struct ggml_init_params { void * mem_buffer; // if NULL, memory will be allocated internally }; -void ggml_time_init(void); +void ggml_time_init(void); // call this once at the beginning of the program int64_t ggml_time_ms(void); int64_t ggml_time_us(void); int64_t ggml_cycles(void); diff --git a/src/ggml.c b/src/ggml.c index 1000a5b..484b6dc 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -14,7 +14,7 @@ #include #include -#if defined _MSC_VER +#if defined _MSC_VER || defined(__MINGW32__) #include typedef volatile LONG atomic_int; @@ -37,8 +37,14 @@ typedef HANDLE pthread_t; typedef DWORD thread_ret_t; static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { - out = CreateThread(NULL, 0, func, arg, 0, NULL); - return out != NULL; + HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; } static int pthread_join(pthread_t thread, void* unused) { @@ -198,7 +204,7 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // timing // -#if defined(_MSC_VER) +#if defined(_MSC_VER) || defined(__MINGW32__) static int64_t timer_freq; void ggml_time_init(void) { LARGE_INTEGER frequency; @@ -3150,7 +3156,10 @@ void ggml_compute_forward_add_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - for (int j = ith; j < n; j += nth) { + const int j0 = (n/nth)*ith; + const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); + + for (int j = j0; j < j1; j++) { ggml_vec_add_f32(nc, (float *) ((char *) dst->data + j*nb1), (float *) ((char *) src0->data + j*nb01), @@ -6857,7 +6866,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_ADD: { - node->n_tasks = 1; + node->n_tasks = n_threads; } break; case GGML_OP_SUB: case GGML_OP_MUL: @@ -8089,7 +8098,7 @@ int ggml_cpu_has_avx512(void) { } int ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON__) +#if defined(__ARM_NEON) return 1; #else return 0;
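The ADD operator is now scheduled with node->n_tasks = n_threads, and each worker handles a contiguous block of rows, with the last thread absorbing the remainder. A standalone sketch of that partitioning, not part of the patch, just to make the arithmetic explicit:

// sketch: the per-thread row range used by the multi-threaded ADD path
// ith = thread index in [0, nth), nth = number of threads, n = number of rows
static void add_row_range(int ith, int nth, int n, int * j0, int * j1) {
    *j0 = (n/nth)*ith;
    *j1 = (ith == nth - 1) ? n : (n/nth)*(ith + 1); // last thread also takes the remainder
}

// e.g. n = 10 rows over nth = 4 threads -> [0,2), [2,4), [4,6), [6,10)

Contiguous ranges keep each thread's reads and writes adjacent in memory, unlike the previous interleaved j += nth loop.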