From f56828ed7836ce18c6bfdca0fb7476ac720e5d77 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Dec 2022 11:06:13 +0200 Subject: [PATCH] ggml : sync with latest code from whisper.cpp --- examples/whisper/main.cpp | 293 ++++++++++++++---------- examples/whisper/whisper.cpp | 428 ++++++++++++++++++++++------------- examples/whisper/whisper.h | 95 +++++--- include/ggml/ggml.h | 1 + src/ggml.c | 175 +++++++++++++- 5 files changed, 674 insertions(+), 318 deletions(-) diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 7058031..465d43f 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) { return std::string(buf); } +int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + // helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { @@ -48,7 +52,6 @@ void replace_all(std::string & s, const std::string & search, const std::string // command-line parameters struct whisper_params { - int32_t seed = -1; // RNG seed, not used currently int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_processors = 1; int32_t offset_t_ms = 0; @@ -59,15 +62,16 @@ struct whisper_params { float word_thold = 0.01f; - bool verbose = false; - bool translate = false; - bool output_txt = false; - bool output_vtt = false; - bool output_srt = false; - bool output_wts = false; - bool print_special_tokens = false; - bool print_colors = false; - bool no_timestamps = false; + bool speed_up = false; + bool translate = false; + bool diarize = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool print_special = false; + bool print_colors = false; + bool no_timestamps = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -86,57 +90,32 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { continue; } - if (arg == "-s" || arg == "--seed") { - params.seed = std::stoi(argv[++i]); - } else if (arg == "-t" || arg == "--threads") { - params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-p" || arg == "--processors") { - params.n_processors = std::stoi(argv[++i]); - } else if (arg == "-ot" || arg == "--offset-t") { - params.offset_t_ms = std::stoi(argv[++i]); - } else if (arg == "-on" || arg == "--offset-n") { - params.offset_n = std::stoi(argv[++i]); - } else if (arg == "-d" || arg == "--duration") { - params.duration_ms = std::stoi(argv[++i]); - } else if (arg == "-mc" || arg == "--max-context") { - params.max_context = std::stoi(argv[++i]); - } else if (arg == "-ml" || arg == "--max-len") { - params.max_len = std::stoi(argv[++i]); - } else if (arg == "-wt" || arg == "--word-thold") { - params.word_thold = std::stof(argv[++i]); - } else if (arg == "-v" || arg == "--verbose") { - params.verbose = true; - } else if (arg == "--translate") { - params.translate = true; - } else if (arg == "-l" || arg == "--language") { - params.language = argv[++i]; - if (whisper_lang_id(params.language.c_str()) == -1) { - fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); - whisper_print_usage(argc, argv, params); - exit(0); - } - } else if (arg == "-otxt" || arg == "--output-txt") { - params.output_txt = true; - } else if (arg == "-ovtt" || arg == "--output-vtt") { - params.output_vtt = true; - } else if (arg == "-osrt" || arg == "--output-srt") { - params.output_srt = true; - } else if (arg == "-owts" || arg == "--output-words") { - params.output_wts = true; - } else if (arg == "-ps" || arg == "--print_special") { - params.print_special_tokens = true; - } else if (arg == "-pc" || arg == "--print_colors") { - params.print_colors = true; - } else if (arg == "-nt" || arg == "--no_timestamps") { - params.no_timestamps = true; - } else if (arg == "-m" || arg == "--model") { - params.model = argv[++i]; - } else if (arg == "-f" || arg == "--file") { - params.fname_inp.push_back(argv[++i]); - } else if (arg == "-h" || arg == "--help") { + if (arg == "-h" || arg == "--help") { whisper_print_usage(argc, argv, params); exit(0); - } else { + } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); } + else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); exit(0); @@ -151,33 +130,40 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); - fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); - fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); - fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); - fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); - fprintf(stderr, " -v, --verbose verbose output\n"); - fprintf(stderr, " --translate translate from source language to english\n"); - fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); - fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); - fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); - fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n"); - fprintf(stderr, " -ps, --print_special print special tokens\n"); - fprintf(stderr, " -pc, --print_colors print colors\n"); - fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); - fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } +struct whisper_print_user_data { + const whisper_params * params; + + const std::vector> * pcmf32s; +}; + void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { - const whisper_params & params = *(whisper_params *) user_data; + const auto & params = *((whisper_print_user_data *) user_data)->params; + const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const int n_segments = whisper_full_n_segments(ctx); @@ -191,7 +177,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi if (params.no_timestamps) { if (params.print_colors) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { + if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); if (id >= whisper_token_eot(ctx)) { continue; @@ -214,10 +200,37 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "(speaker 0)"; + } else if (energy1 > 1.1*energy0) { + speaker = "(speaker 1)"; + } else { + speaker = "(speaker ?)"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } + if (params.print_colors) { printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { + if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); if (id >= whisper_token_eot(ctx)) { continue; @@ -229,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } printf("\n"); } else { const char * text = whisper_full_get_segment_text(ctx, i); - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text); } } } @@ -263,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); - return 9; + return false; } fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); @@ -386,9 +399,9 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f ncnt += txt.size(); } - ::replace_all(txt_bg, "'", "’"); + ::replace_all(txt_bg, "'", "\u2019"); ::replace_all(txt_bg, "\"", "\\\""); - ::replace_all(txt_fg, "'", "’"); + ::replace_all(txt_fg, "'", "\u2019"); ::replace_all(txt_fg, "\"", "\\\""); } @@ -428,16 +441,18 @@ int main(int argc, char ** argv) { return 1; } - if (params.seed < 0) { - params.seed = time(NULL); - } - if (params.fname_inp.empty()) { fprintf(stderr, "error: no input files specified\n"); whisper_print_usage(argc, argv, params); return 2; } + if (whisper_lang_id(params.language.c_str()) == -1) { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + // whisper init struct whisper_context * ctx = whisper_init(params.model.c_str()); @@ -450,53 +465,60 @@ int main(int argc, char ** argv) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + // WAV input - std::vector pcmf32; { drwav wav; - + std::vector wav_data; // used for pipe input from stdin + if (fname_inp == "-") { - std::vector wav_data; { uint8_t buf[1024]; while (true) { const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) - { + if (n == 0) { break; } wav_data.insert(wav_data.end(), buf, buf + n); } } - if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) - { + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) { fprintf(stderr, "error: failed to open WAV file from stdin\n"); return 4; } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); } else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); - return 4; + return 5; } if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); - return 5; + return 6; + } + + if (params.diarize && wav.channels != 2 && params.no_timestamps == false) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str()); + return 6; } if (wav.sampleRate != WHISPER_SAMPLE_RATE) { fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); - return 6; + return 8; } if (wav.bitsPerSample != 16) { fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); - return 7; + return 9; } - int n = wav.totalPCMFrameCount; + const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); std::vector pcm16; pcm16.resize(n*wav.channels); @@ -514,6 +536,18 @@ int main(int argc, char ** argv) { pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; } } + + if (params.diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (int i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } } // print system information @@ -548,30 +582,47 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = false; - wparams.print_progress = false; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special_tokens = params.print_special_tokens; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.token_timestamps = params.output_wts || params.max_len > 0; - wparams.thold_pt = params.word_thold; - wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + wparams.print_realtime = false; + wparams.print_progress = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + + wparams.speed_up = params.speed_up; + + whisper_print_user_data user_data = { ¶ms, &pcmf32s }; // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; - wparams.new_segment_callback_user_data = ¶ms; + wparams.new_segment_callback_user_data = &user_data; + } + + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; } if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 8; + return 10; } } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 7078863..42467ef 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -424,6 +424,9 @@ struct whisper_context { int64_t t_last; whisper_token tid_last; std::vector energy; // PCM signal energy + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx; // 0 - use default }; // load the model from a ggml file @@ -515,15 +518,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); - - // this is the total memory required to run the inference - const size_t mem_required = - wctx.buf_model->size() + - wctx.buf_memory.size() + - wctx.buf_compute.size() + - wctx.buf_compute_layer.size(); - - fprintf(stderr, "%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); } // load mel filters @@ -596,11 +590,21 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx } } + { + // this is the total memory required to run the inference + const size_t mem_required = + wctx.buf_model->size() + + wctx.buf_memory.size() + + wctx.buf_compute.size() + + wctx.buf_compute_layer.size(); + + fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + } + // for the big tensors, we have the option to store the data in 16-bit floats // in order to save memory and also to speed up the computation const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; - size_t ctx_size = 0; size_t ctx_mem_size = 0; @@ -613,7 +617,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_ctx = hparams.n_text_ctx; const int n_text_state = hparams.n_text_state; const int n_text_layer = hparams.n_text_layer; @@ -719,7 +723,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead - fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); } // create the ggml context @@ -748,7 +752,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_ctx = hparams.n_text_ctx; const int n_text_state = hparams.n_text_state; const int n_text_layer = hparams.n_text_layer; @@ -967,7 +971,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx // key/value memory for the cross-attention layer { - const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; @@ -980,7 +984,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); - fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0); + fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } // load weights @@ -1039,12 +1043,12 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); - //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); model.n_loaded++; } - fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); + fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); if (model.n_loaded == 0) { fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); @@ -1076,13 +1080,11 @@ static bool whisper_encode( const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; - const int n_ctx = hparams.n_audio_ctx; + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; - const int N = n_ctx; - const int n_mels = hparams.n_mels; assert(mel_inp.n_mel == n_mels); @@ -1132,7 +1134,30 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } - cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); struct ggml_tensor * inpL = cur; @@ -1198,14 +1223,14 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_cpy(ctxL, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctxL, ggml_cpy(ctxL, Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = @@ -1213,9 +1238,9 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, - n_state/n_head, n_head, N), + n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head) + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); @@ -1224,14 +1249,14 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_cpy(ctxL, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctxL, ggml_cpy(ctxL, Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); // K * Q @@ -1249,7 +1274,7 @@ static bool whisper_encode( // ggml_permute(ctxL, // ggml_cpy(ctxL, // Vcur, - // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); @@ -1259,9 +1284,9 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, - n_state/n_head, n_head, N), + n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); @@ -1271,7 +1296,7 @@ static bool whisper_encode( cur = ggml_cpy(ctxL, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); } // projection @@ -1425,6 +1450,8 @@ static bool whisper_encode( Vcross), Vcross); + //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); @@ -1474,7 +1501,7 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - const int M = hparams.n_audio_ctx; + const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; struct ggml_init_params params = { .mem_size = wctx.buf_compute.size(), @@ -1819,7 +1846,9 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, - const float * probs) { + const float * probs, + bool force_timestamp, + bool is_initial) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -1842,7 +1871,18 @@ static whisper_token_data whisper_sample_best( max_tx = std::max(max_tx, probs_id[i].first); } - for (int i = vocab.token_beg; i < n_logits; i++) { + const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; + const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits; + + // the initial timestamp cannot be larger than 100 + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial) { + for (int i = i0; i < n_logits; ++ i) { + probs_id[i].first = -INFINITY; + } + } + + for (int i = vocab.token_beg; i < i1; i++) { sum_ts += probs_id[i].first; if (probs_id[i].first > max_ts) { max_ts = probs_id[i].first; @@ -1852,7 +1892,7 @@ static whisper_token_data whisper_sample_best( // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a // timestamp token - if (sum_ts > max_tx) { + if (sum_ts > max_tx || force_timestamp) { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for (int i = 0; i < vocab.token_beg; i++) { probs_id[i].first = -INFINITY; @@ -1894,39 +1934,6 @@ static whisper_token_data whisper_sample_best( return result; } -// samples only from the timestamps tokens -static whisper_vocab::id whisper_sample_timestamp( - const whisper_vocab & vocab, - const float * probs) { - int n_logits = vocab.id_to_token.size(); - - std::vector> probs_id; - probs_id.reserve(n_logits); - - for (int i = vocab.token_beg + 1; i < n_logits; i++) { - probs_id.push_back(std::make_pair(probs[i], i)); - } - - const int top_k = 10; - - // find the top K tokens - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - return probs_id[0].second; -} - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2031,6 +2038,7 @@ static bool log_mel_spectrogram( const int n_mel, const int n_threads, const whisper_filters & filters, + const bool speed_up, whisper_mel & mel) { // Hanning window @@ -2044,7 +2052,7 @@ static bool log_mel_spectrogram( mel.n_len = (n_samples)/fft_step; mel.data.resize(mel.n_mel*mel.n_len); - const int n_fft = 1 + fft_size/2; + const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2); //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); @@ -2091,6 +2099,13 @@ static bool log_mel_spectrogram( //} } + if (speed_up) { + // scale down in the frequency domain results in a speed up in the time domain + for (int j = 0; j < n_fft; j++) { + fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); + } + } + // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { double sum = 0.0; @@ -2161,6 +2176,12 @@ struct whisper_context * whisper_init(const char * path_model) { void whisper_free(struct whisper_context * ctx) { if (ctx) { + if (ctx->model.ctx) { + ggml_free(ctx->model.ctx); + } + if (ctx->model.ctx_mem) { + ggml_free(ctx->model.ctx_mem); + } if (ctx->buf_model) { delete ctx->buf_model; } @@ -2171,7 +2192,21 @@ void whisper_free(struct whisper_context * ctx) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { const int64_t t_start_us = ggml_time_us(); - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) { + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + ctx->t_mel_us = ggml_time_us() - t_start_us; + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2229,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; return res; } -whisper_token whisper_sample_timestamp(struct whisper_context * ctx) { +struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2305,11 +2338,11 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) { return ctx->vocab.token_beg; } -whisper_token whisper_token_translate() { +whisper_token whisper_token_translate(void) { return whisper_vocab::token_translate; } -whisper_token whisper_token_transcribe() { +whisper_token whisper_token_transcribe(void) { return whisper_vocab::token_transcribe; } @@ -2325,6 +2358,27 @@ void whisper_print_timings(struct whisper_context * ctx) { fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } +void whisper_reset_timings(struct whisper_context * ctx) { + ctx->t_sample_us = 0; + ctx->t_encode_us = 0; + ctx->t_decode_us = 0; +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + + return s.c_str(); +} + //////////////////////////////////////////////////////////////////////////// struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { @@ -2334,77 +2388,99 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str case WHISPER_SAMPLING_GREEDY: { result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - /*.translate =*/ false, - /*.no_context =*/ false, - /*.print_special_tokens =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, - /*.language =*/ "en", + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, - /*.greedy =*/ { + /*.language =*/ "en", + + /*.greedy =*/ { /*.n_past =*/ 0, }, - /*.beam_search =*/ { + /*.beam_search =*/ { /*.n_past =*/ -1, /*.beam_width =*/ -1, /*.n_best =*/ -1, }, - /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { result = { - /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, + /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, - /*.translate =*/ false, - /*.no_context =*/ false, - /*.print_special_tokens =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, - /*.language =*/ "en", + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, - /*.greedy =*/ { + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.greedy =*/ { /*.n_past =*/ -1, }, - /*.beam_search =*/ { + /*.beam_search =*/ { /*.n_past =*/ 0, /*.beam_width =*/ 10, /*.n_best =*/ 5, }, - /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, }; } break; } @@ -2485,9 +2561,16 @@ int whisper_full( result_all.clear(); // compute log mel spectrogram - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -1; + if (params.speed_up) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } + } else { + if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } } if (params.token_timestamps) { @@ -2513,6 +2596,18 @@ int whisper_full( prompt_past.clear(); } + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + + // overwrite audio_ctx + ctx->exp_n_audio_ctx = params.audio_ctx; + // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { @@ -2548,6 +2643,13 @@ int whisper_full( break; } + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + // encode audio features starting at offset seek if (whisper_encode(ctx, seek, params.n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); @@ -2570,7 +2672,6 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - bool done = false; int seek_delta = 100*WHISPER_CHUNK_SIZE; // print the prompt @@ -2584,7 +2685,9 @@ int whisper_full( int result_len = 0; tokens_cur.clear(); - for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) { + bool failed = false; + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return 8; @@ -2601,15 +2704,19 @@ int whisper_full( // feel free to experiment! // { - auto token = whisper_sample_best(ctx); - - if (i == 0) { - token.tid = whisper_token_beg(ctx); - } + const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { - seek_delta = 2*(token.id - whisper_token_beg(ctx)); + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (seek_delta != 100*WHISPER_CHUNK_SIZE && + seek_delta > seek_delta_new && result_len < i) { + break; + } + + seek_delta = seek_delta_new; result_len = i + 1; } @@ -2619,19 +2726,25 @@ int whisper_full( //{ // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - // printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str()); + // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); //} // end of text token - if (token.id == whisper_token_eot(ctx)) { + if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) { if (result_len == 0) { if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; } else { - // TODO: figure out how to resolve this - fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); + failed = true; + break; } } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + break; } @@ -2642,11 +2755,21 @@ int whisper_full( } } - if (done) { + // sometimes, the decoding can get stuck in a repetition loop + // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance + // the sliding window by 1 second + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; break; } } + if (failed) { + fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__); + seek += 100; + continue; + } + // shrink down to result_len tokens_cur.resize(result_len); @@ -2666,23 +2789,26 @@ int whisper_full( // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { } else { text += whisper_token_to_str(ctx, tokens_cur[i].id); } - if (tokens_cur[i].id > whisper_token_beg(ctx)) { + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + if (params.print_realtime) { if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); fflush(stdout); } } - result_all.push_back({ t0, t1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -2714,16 +2840,19 @@ int whisper_full( if (!text.empty()) { const auto t1 = seek + seek_delta; + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + if (params.print_realtime) { if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); fflush(stdout); } } - result_all.push_back({ t0, t1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -2755,7 +2884,7 @@ int whisper_full_parallel( struct whisper_full_params params, const float * samples, int n_samples, - const int n_processors) { + int n_processors) { if (n_processors == 1) { return whisper_full(ctx, params, samples, n_samples); } @@ -2805,7 +2934,7 @@ int whisper_full_parallel( // key/value memory for the cross-attention layer { - const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; @@ -2813,10 +2942,6 @@ int whisper_full_parallel( model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); } - - const size_t memory_size = - ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + - ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); } } @@ -2936,20 +3061,6 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int return ctx->result_all[i_segment].tokens[i_token].p; } -const char * whisper_print_system_info() { - static std::string s; - - s = ""; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - - return s.c_str(); -} - // ================================================================================================= // @@ -3036,9 +3147,6 @@ static void whisper_exp_compute_token_level_timestamps( const int64_t t0 = segment.t0; const int64_t t1 = segment.t1; - const int s0 = timestamp_to_sample(t0, n_samples); - const int s1 = timestamp_to_sample(t1, n_samples); - const int n = tokens.size(); if (n == 0) { diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 4c112f4..def77d4 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -72,16 +72,16 @@ extern "C" { whisper_token id; // token id whisper_token tid; // forced timestamp token id - float p; // probability of the token - float pt; // probability of the timestamp token - float ptsum; // sum of probabilities of all timestamp tokens + float p; // probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens // token-level timestamp data // do not use if you haven't computed token-level timestamps - int64_t t0; // start time of the token - int64_t t1; // end time of the token + int64_t t0; // start time of the token + int64_t t1; // end time of the token - float vlen; // voice length of the token + float vlen; // voice length of the token } whisper_token_data; // Allocates all memory needed for the model and loads the model from the given file. @@ -96,9 +96,9 @@ extern "C" { // Returns 0 on success WHISPER_API int whisper_pcm_to_mel( struct whisper_context * ctx, - const float * samples, - int n_samples, - int n_threads); + const float * samples, + int n_samples, + int n_threads); // This can be used to set a custom log mel spectrogram inside the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. @@ -106,9 +106,9 @@ extern "C" { // Returns 0 on success WHISPER_API int whisper_set_mel( struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel); + const float * data, + int n_len, + int n_mel); // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. @@ -116,8 +116,8 @@ extern "C" { // Returns 0 on success WHISPER_API int whisper_encode( struct whisper_context * ctx, - int offset, - int n_threads); + int offset, + int n_threads); // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. @@ -126,10 +126,10 @@ extern "C" { // Returns 0 on success WHISPER_API int whisper_decode( struct whisper_context * ctx, - const whisper_token * tokens, - int n_tokens, - int n_past, - int n_threads); + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); // Token sampling methods. // These are provided for convenience and can be used after each call to whisper_decode(). @@ -137,7 +137,7 @@ extern "C" { // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); // Return the id of the specified language, returns -1 if not found WHISPER_API int whisper_lang_id(const char * lang); @@ -162,11 +162,15 @@ extern "C" { WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); // Task tokens - WHISPER_API whisper_token whisper_token_translate (); - WHISPER_API whisper_token whisper_token_transcribe(); + WHISPER_API whisper_token whisper_token_translate (void); + WHISPER_API whisper_token whisper_token_transcribe(void); // Performance information WHISPER_API void whisper_print_timings(struct whisper_context * ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + + // Print system information + WHISPER_API const char * whisper_print_system_info(void); //////////////////////////////////////////////////////////////////////////// @@ -181,17 +185,26 @@ extern "C" { // Use the whisper_full_...() functions to obtain the text segments typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + + // Parameters for the whisper_full() function + // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() struct whisper_full_params { enum whisper_sampling_strategy strategy; int n_threads; int n_max_text_ctx; - int offset_ms; // start offset in ms - int duration_ms; // audio duration to process in ms + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms bool translate; bool no_context; - bool print_special_tokens; + bool single_segment; // force single segment output (useful for streaming) + bool print_special; bool print_progress; bool print_realtime; bool print_timestamps; @@ -201,6 +214,16 @@ extern "C" { float thold_pt; // timestamp token probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01) int max_len; // max segment length in characters + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + int audio_ctx; // overwrite the audio context size (0 = use default) + + // tokens to provide the whisper model as initial prompt + // these are prepended to any existing text context from a previous call + const whisper_token * prompt_tokens; + int prompt_n_tokens; const char * language; @@ -216,6 +239,9 @@ extern "C" { whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); @@ -223,20 +249,20 @@ extern "C" { // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Uses the specified decoding strategy to obtain the text. WHISPER_API int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples); + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); // Split the input audio in chunks and process each chunk separately using whisper_full() // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. WHISPER_API int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - const int n_processors); + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); // Number of generated text segments. // A segment can be a few words, a sentence, or even a paragraph. @@ -263,9 +289,6 @@ extern "C" { // Get the probability of the specified token in the specified segment. WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); - // Print system information - WHISPER_API const char * whisper_print_system_info(); - #ifdef __cplusplus } #endif diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index f352e71..3e4e962 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt( // system info // +int ggml_cpu_has_avx(void); int ggml_cpu_has_avx2(void); int ggml_cpu_has_avx512(void); int ggml_cpu_has_neon(void); diff --git a/src/ggml.c b/src/ggml.c index 484b6dc..b6d528d 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -15,7 +15,14 @@ #include #if defined _MSC_VER || defined(__MINGW32__) + +#if !defined(__MINGW32__) #include +#else +// ref: https://github.com/ggerganov/whisper.cpp/issues/168 +#include +#include +#endif typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; @@ -37,7 +44,7 @@ typedef HANDLE pthread_t; typedef DWORD thread_ret_t; static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { - HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL); + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); if (handle == NULL) { return EAGAIN; @@ -372,6 +379,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float sumf = _mm_cvtss_f32(r1); + // leftovers + for (int i = n32; i < n; ++i) { + sumf += x[i]*y[i]; + } +#elif defined(__AVX__) + // AVX 256-bit + const int n32 = (n & ~31); + + __m256 sum0 = _mm256_setzero_ps(); + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 sum3 = _mm256_setzero_ps(); + + __m256 x0, x1, x2, x3; + __m256 y0, y1, y2, y3; + + for (int i = 0; i < n32; i += 32) { + x0 = _mm256_loadu_ps(x + i + 0); + x1 = _mm256_loadu_ps(x + i + 8); + x2 = _mm256_loadu_ps(x + i + 16); + x3 = _mm256_loadu_ps(x + i + 24); + + y0 = _mm256_loadu_ps(y + i + 0); + y1 = _mm256_loadu_ps(y + i + 8); + y2 = _mm256_loadu_ps(y + i + 16); + y3 = _mm256_loadu_ps(y + i + 24); + + sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); + sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); + sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); + sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); + } + + sum0 = _mm256_add_ps(sum0, sum1); + sum2 = _mm256_add_ps(sum2, sum3); + sum0 = _mm256_add_ps(sum0, sum2); + + const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1)); + const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); + const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); + + sumf = _mm_cvtss_f32(r1); + // leftovers for (int i = n32; i < n; ++i) { sumf += x[i]*y[i]; @@ -569,6 +619,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t sumf = _mm_cvtss_f32(r1); + // leftovers + for (int i = n32; i < n; ++i) { + //GGML_ASSERT(false); + sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + } +#elif defined(__AVX__) + // AVX 256-bit + const int n32 = (n & ~31); + + __m256 sum0 = _mm256_setzero_ps(); + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 sum3 = _mm256_setzero_ps(); + + __m256 x0, x1, x2, x3; + __m256 y0, y1, y2, y3; + + for (int i = 0; i < n32; i += 32) { + x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); + x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); + x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); + x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); + + y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); + y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); + y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); + y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + + sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); + sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); + sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); + sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); + } + + const __m256 sum01 = _mm256_add_ps(sum0, sum1); + const __m256 sum23 = _mm256_add_ps(sum2, sum3); + const __m256 sum0123 = _mm256_add_ps(sum01, sum23); + + const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1)); + const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); + const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); + + sumf = _mm_cvtss_f32(r1); + // leftovers for (int i = n32; i < n; ++i) { //GGML_ASSERT(false); @@ -698,6 +792,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float _mm256_storeu_ps(y + i + 24, y3); } + // leftovers + for (int i = n32; i < n; ++i) { + y[i] += x[i]*v; + } +#elif defined(__AVX__) + // AVX 256-bit + const int n32 = (n & ~31); + + const __m256 v4 = _mm256_set1_ps(v); + + __m256 x0, x1, x2, x3; + __m256 y0, y1, y2, y3; + + for (int i = 0; i < n32; i += 32) { + x0 = _mm256_loadu_ps(x + i + 0); + x1 = _mm256_loadu_ps(x + i + 8); + x2 = _mm256_loadu_ps(x + i + 16); + x3 = _mm256_loadu_ps(x + i + 24); + + y0 = _mm256_loadu_ps(y + i + 0); + y1 = _mm256_loadu_ps(y + i + 8); + y2 = _mm256_loadu_ps(y + i + 16); + y3 = _mm256_loadu_ps(y + i + 24); + + y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0); + y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1); + y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2); + y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3); + + _mm256_storeu_ps(y + i + 0, y0); + _mm256_storeu_ps(y + i + 8, y1); + _mm256_storeu_ps(y + i + 16, y2); + _mm256_storeu_ps(y + i + 24, y3); + } + // leftovers for (int i = n32; i < n; ++i) { y[i] += x[i]*v; @@ -859,6 +988,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); } + // leftovers + for (int i = n32; i < n; ++i) { + GGML_ASSERT(false); + y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + } +#elif defined(__AVX__) + // AVX 256-bit + const int n32 = (n & ~31); + + const __m256 v8 = _mm256_set1_ps(v); + + __m256 x0, x1, x2, x3; + __m256 y0, y1, y2, y3; + + for (int i = 0; i < n32; i += 32) { + y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); + y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); + y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); + y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + + x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); + x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); + x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); + x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); + + y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0); + y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1); + y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2); + y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3); + + _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); + _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); + _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0)); + _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); + } + // leftovers for (int i = n32; i < n; ++i) { GGML_ASSERT(false); @@ -8081,6 +8246,14 @@ enum ggml_opt_result ggml_opt( //////////////////////////////////////////////////////////////////////////////// +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_avx2(void) { #if defined(__AVX2__) return 1;