From bd9f710a45dede925f4c1058207eceb225b05508 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 31 Dec 2022 12:32:04 +0200
Subject: [PATCH] sync : latest changes from whisper.cpp

---
 examples/whisper/main.cpp    |  201 ++-
 examples/whisper/whisper.cpp |  420 ++-
 examples/whisper/whisper.h   |   35 +
 include/ggml/ggml.h          |   57 +-
 src/ggml.c                   | 1774 +++++++++++++++++-----
 5 files changed, 1426 insertions(+), 1061 deletions(-)

diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp
index 465d43f..ce8b484 100644
--- a/examples/whisper/main.cpp
+++ b/examples/whisper/main.cpp
@@ -62,19 +62,22 @@ struct whisper_params {
     float word_thold = 0.01f;

-    bool speed_up = false;
-    bool translate = false;
-    bool diarize = false;
-    bool output_txt = false;
-    bool output_vtt = false;
-    bool output_srt = false;
-    bool output_wts = false;
-    bool print_special = false;
-    bool print_colors = false;
-    bool no_timestamps = false;
-
-    std::string language = "en";
-    std::string model = "models/ggml-base.en.bin";
+    bool speed_up = false;
+    bool translate = false;
+    bool diarize = false;
+    bool output_txt = false;
+    bool output_vtt = false;
+    bool output_srt = false;
+    bool output_wts = false;
+    bool output_csv = false;
+    bool print_special = false;
+    bool print_colors = false;
+    bool print_progress = false;
+    bool no_timestamps = false;
+
+    std::string language = "en";
+    std::string prompt;
+    std::string model = "models/ggml-base.en.bin";

     std::vector<std::string> fname_inp = {};
 };
@@ -94,27 +97,30 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
-        else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
-        else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
-        else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
-        else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
-        else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
-        else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
-        else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
-        else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
-        else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
-        else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
-        else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
-        else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
-        else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
-        else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
-        else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
-        else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
-        else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
-        else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
-        else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
-        else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
+        else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
+        else if (arg
== "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -125,33 +131,36 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { return true; } -void whisper_print_usage(int argc, char ** argv, const whisper_params & params) { +void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? 
"true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? 
"true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } @@ -200,7 +209,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; + std::string speaker; if (params.diarize && pcmf32s.size() == 2) { const int64_t n_samples = pcmf32s[0].size(); @@ -266,7 +275,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - fout << text; + fout << text << "\n"; } return true; @@ -319,10 +328,36 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ return true; } +bool output_csv(struct whisper_context * ctx, const char * fname) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + if (text[0] == ' ') + text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. 
+        fout << 10 * t0 << ", "
+             << 10 * t1 << ", \""
+             << text << "\"\n";
+    }
+
+    return true;
+}
+
+
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
     std::ofstream fout(fname);

     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);

@@ -371,7 +406,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                 txt_ul = "\\ \\ ";

                 {
-                    int ncnt = 0;
                     for (int k = 0; k < n; ++k) {
                         const auto & token2 = tokens[k];

@@ -395,8 +429,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                                 txt_ul += "\\ ";
                             }
                         }
-
-                        ncnt += txt.size();
                     }

                     ::replace_all(txt_bg, "'", "\u2019");
@@ -447,7 +479,7 @@ int main(int argc, char ** argv) {
        return 2;
    }

-    if (whisper_lang_id(params.language.c_str()) == -1) {
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
        whisper_print_usage(argc, argv, params);
        exit(0);
@@ -462,6 +494,22 @@ int main(int argc, char ** argv) {
        return 3;
    }

+    // initial prompt
+    std::vector<whisper_token> prompt_tokens;
+
+    if (!params.prompt.empty()) {
+        prompt_tokens.resize(1024);
+        prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
+
+        fprintf(stderr, "\n");
+        fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
+        fprintf(stderr, "initial tokens: [ ");
+        for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
+            fprintf(stderr, "%d ", prompt_tokens[i]);
+        }
+        fprintf(stderr, "]\n");
+    }
+
    for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
        const auto fname_inp = params.fname_inp[f];

@@ -486,14 +534,14 @@ int main(int argc, char ** argv) {
                }
            }

-            if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
+            if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
                fprintf(stderr, "error: failed to open WAV file from stdin\n");
                return 4;
            }

            fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
        }
-        else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
+        else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
            fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
            return 5;
        }
@@ -528,11 +576,11 @@ int main(int argc, char ** argv) {
        // convert to mono, float
        pcmf32.resize(n);
        if (wav.channels == 1) {
-            for (int i = 0; i < n; i++) {
+            for (uint64_t i = 0; i < n; i++) {
                pcmf32[i] = float(pcm16[i])/32768.0f;
            }
        } else {
-            for (int i = 0; i < n; i++) {
+            for (uint64_t i = 0; i < n; i++) {
                pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
            }
        }
@@ -543,7 +591,7 @@ int main(int argc, char ** argv) {
            pcmf32s[0].resize(n);
            pcmf32s[1].resize(n);

-            for (int i = 0; i < n; i++) {
+            for (uint64_t i = 0; i < n; i++) {
                pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
                pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
            }
@@ -577,13 +625,12 @@ int main(int argc, char ** argv) {
            fprintf(stderr, "\n");
        }

-        // run the inference
        {
            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

            wparams.print_realtime = false;
-            wparams.print_progress = false;
+            wparams.print_progress = params.print_progress;
            wparams.print_timestamps = !params.no_timestamps;
            wparams.print_special = params.print_special;
            wparams.translate = params.translate;
@@ -599,6 +646,9 @@ int main(int argc, char ** argv) {

            wparams.speed_up = params.speed_up;

+            wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
+
            whisper_print_user_data user_data = { &params, &pcmf32s };

            // this callback is called on each new segment
@@ -613,7 +663,7 @@ int main(int argc, char ** argv) {
            {
                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-                wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
                    bool is_aborted = *(bool*)user_data;
                    return !is_aborted;
                };
@@ -653,6 +703,13 @@ int main(int argc, char ** argv) {
                const auto fname_wts = fname_inp + ".wts";
                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
            }
+
+            // output to CSV file
+            if (params.output_csv) {
+                const auto fname_csv = fname_inp + ".csv";
+                output_csv(ctx, fname_csv.c_str());
+            }
+
        }
    }

diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp
index 42467ef..84c2490 100644
--- a/examples/whisper/whisper.cpp
+++ b/examples/whisper/whisper.cpp
@@ -14,6 +14,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <regex>

 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
@@ -203,6 +204,10 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;

+    // used to avoid memory allocations during sampling
+    // TODO: move to whisper_context in the future
+    std::vector<std::pair<double, id>> probs_id;
+
     id token_eot = 50256;
     id token_sot = 50257;
     id token_prev = 50360;
@@ -429,6 +434,12 @@ struct whisper_context {
     int32_t exp_n_audio_ctx; // 0 - use default
 };

+template<typename T>
+static void read_safe(std::ifstream& fin, T& dest)
+{
+    fin.read((char*)& dest, sizeof(T));
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -455,7 +466,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
    // verify magic
    {
        uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
+        read_safe(fin, magic);
        if (magic != 0x67676d6c) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
            return false;
@@ -466,17 +477,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
    {
        auto & hparams = model.hparams;

-        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
-        fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
-        fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
-        fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
-        fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
-        fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
-        fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
-        fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
-        fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
-        fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
-        fin.read((char *) &hparams.f16, sizeof(hparams.f16));
+        read_safe(fin, hparams.n_vocab);
+        read_safe(fin, hparams.n_audio_ctx);
+        read_safe(fin, hparams.n_audio_state);
+        read_safe(fin, hparams.n_audio_head);
+        read_safe(fin, hparams.n_audio_layer);
+        read_safe(fin, hparams.n_text_ctx);
+        read_safe(fin, hparams.n_text_state);
+        read_safe(fin, hparams.n_text_head);
+        read_safe(fin, hparams.n_text_layer);
+        read_safe(fin, hparams.n_mels);
+        read_safe(fin, hparams.f16);

        assert(hparams.n_text_state == hparams.n_audio_state);
@@ -524,8 +535,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
    {
        auto & filters = wctx.model.filters;

-        fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
-        fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
+        read_safe(fin, filters.n_mel);
+        read_safe(fin, filters.n_fft);

        filters.data.resize(filters.n_mel * filters.n_fft);
        fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +545,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
    // load vocab
    {
        int32_t n_vocab = 0;
-        fin.read((char *) &n_vocab, sizeof(n_vocab));
+        read_safe(fin, n_vocab);

        //if (n_vocab != model.hparams.n_vocab) {
        //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -543,12 +554,23 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
        //}

        std::string word;
+        std::vector<char> tmp;
+
+        tmp.reserve(128);
+
        for (int i = 0; i < n_vocab; i++) {
            uint32_t len;
-            fin.read((char *) &len, sizeof(len));
+            read_safe(fin, len);

-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            if (len > 0) {
+                tmp.resize(len);
+                fin.read(&tmp[0], tmp.size()); // read to buffer
+                word.assign(&tmp[0], tmp.size());
+            } else {
+                // seems like we have an empty-string token in multi-language models (i = 50256)
+                //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                word = "";
+            }

            vocab.token_to_id[word] = i;
            vocab.id_to_token[i] = word;
@@ -588,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                vocab.id_to_token[i] = word;
            }
        }
+
+        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+
+        vocab.probs_id.reserve(n_vocab);
    }

    {
@@ -606,7 +633,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
    size_t ctx_size = 0;
-    size_t ctx_mem_size = 0;

    {
        const auto & hparams = model.hparams;
@@ -715,12 +741,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
            ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
        }

-        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
-
-        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
-
        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead

        fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
@@ -728,10 +748,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx

    // create the ggml context
    {
-        struct ggml_init_params params = {
-            .mem_size   = wctx.buf_model->size(),
-            .mem_buffer = wctx.buf_model->data(),
-        };
+        struct ggml_init_params params;
+        params.mem_size = wctx.buf_model->size();
+        params.mem_buffer = wctx.buf_model->data();

        model.ctx = ggml_init(params);
        if (!model.ctx) {
@@ -938,10 +957,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx

    // create the ggml memory context
    {
-        struct ggml_init_params params = {
-            .mem_size   = wctx.buf_memory.size(),
-            .mem_buffer = wctx.buf_memory.data(),
-        };
+        struct ggml_init_params params;
+        params.mem_size = wctx.buf_memory.size();
+        params.mem_buffer = wctx.buf_memory.data();

        model.ctx_mem = ggml_init(params);
        if (!model.ctx_mem) {
@@ -998,9 +1016,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
        int32_t length;
        int32_t ftype;

-        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-        fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
+        read_safe(fin, n_dims);
+        read_safe(fin, length);
+        read_safe(fin, ftype);

        if (fin.eof()) {
            break;
        }
@@ -1009,14 +1027,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
        int32_t nelements = 1;
        int32_t ne[3] = { 1, 1, 1 };
        for (int i = 0; i < n_dims; ++i) {
-            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            read_safe(fin, ne[i]);
            nelements *= ne[i];
        }

-        std::string name(length, 0);
-        fin.read(&name[0], length);
+        std::string name;
+        std::vector<char> tmp(length); // create a buffer
+        fin.read(&tmp[0], tmp.size()); // read to buffer
+        name.assign(&tmp[0], tmp.size());

-        if (model.tensors.find(name.data()) == model.tensors.end()) {
+        if (model.tensors.find(name) == model.tensors.end()) {
            fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
            return false;
        }
@@ -1088,10 +1108,9 @@ static bool whisper_encode(
    const int n_mels = hparams.n_mels;
    assert(mel_inp.n_mel == n_mels);

-    struct ggml_init_params params = {
-        .mem_size   = wctx.buf_compute.size(),
-        .mem_buffer = wctx.buf_compute.data(),
-    };
+    struct ggml_init_params params;
+    params.mem_size = wctx.buf_compute.size();
+    params.mem_buffer = wctx.buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

@@ -1166,10 +1185,9 @@ static bool whisper_encode(

        // create separate context for each layer to reduce memory usage

-        struct ggml_init_params paramsL = {
-            .mem_size   = wctx.buf_compute_layer.size(),
-            .mem_buffer = wctx.buf_compute_layer.data(),
-        };
+        struct ggml_init_params paramsL;
+        paramsL.mem_size = wctx.buf_compute_layer.size();
+        paramsL.mem_buffer = wctx.buf_compute_layer.data();

        struct ggml_context * ctxL = ggml_init(paramsL);

@@ -1374,8 +1392,8 @@ static bool whisper_encode(
        // input for next layer (inpO -> inpL)
        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
        inpL->op = GGML_OP_NONE;
-        inpL->src0 = NULL;
-        inpL->src1 = NULL;
+        inpL->src0 = nullptr;
+        inpL->src1 = nullptr;

        //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);

@@ -1428,8 +1446,8 @@ static bool whisper_encode(

    // TODO: hack to disconnect the encoded features from the previous graph
    cur->op = GGML_OP_NONE;
-    cur->src0 = NULL;
-    cur->src1 = NULL;
+    cur->src0 = nullptr;
+    cur->src1 = nullptr;

    for (int il = 0; il < model.hparams.n_text_layer; ++il) {
        auto & layer = model.layers_decoder[il];
@@ -1503,10 +1521,9 @@ static bool whisper_decode(
    const int N = n_tokens;
    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;

-    struct ggml_init_params params = {
-        .mem_size   = wctx.buf_compute.size(),
-        .mem_buffer = wctx.buf_compute.data(),
-    };
+    struct ggml_init_params params;
+    params.mem_size = wctx.buf_compute.size();
+    params.mem_buffer = wctx.buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

@@ -1529,10 +1546,9 @@
    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers_decoder[il];

-        struct ggml_init_params paramsL = {
-            .mem_size   = wctx.buf_compute_layer.size(),
-            .mem_buffer = wctx.buf_compute_layer.data(),
-        };
+        struct ggml_init_params paramsL;
+        paramsL.mem_size = wctx.buf_compute_layer.size();
+        paramsL.mem_buffer = wctx.buf_compute_layer.data();

        struct ggml_context * ctxL = ggml_init(paramsL);
        struct ggml_cgraph gf = {};
@@ -1788,8 +1804,8 @@
        // input for next layer (inpO -> inpL)
        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
        inpL->op = GGML_OP_NONE;
-        inpL->src0 = NULL;
-        inpL->src1 = NULL;
+        inpL->src0 = nullptr;
+        inpL->src1 = nullptr;

        if (N > 1) {
            //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);

@@ -1845,7 +1861,7 @@
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        const whisper_vocab & vocab,
+        whisper_vocab & vocab,
        const float * probs,
        bool force_timestamp,
        bool is_initial) {
        0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
    };

-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = vocab.n_vocab;

-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
+    auto & probs_id = vocab.probs_id;
+    probs_id.clear();

    for (int i = 0; i < n_logits; i++) {
-        probs_id.push_back(std::make_pair(probs[i], i));
+        probs_id.emplace_back(probs[i], i);
    }

    {
@@ -1997,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
    std::vector<float> even;
    std::vector<float> odd;

+    even.reserve(N/2);
+    odd.reserve(N/2);
+
    for (int i = 0; i < N; i++) {
        if (i % 2 == 0) {
            even.push_back(in[i]);
        } else {
            odd.push_back(in[i]);
        }
    }
@@ -2032,7 +2051,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 static bool log_mel_spectrogram(
    const float * samples,
    const int n_samples,
-    const int sample_rate,
+    const int /*sample_rate*/,
    const int fft_size,
    const int fft_step,
    const int n_mel,
@@ -2151,6 +2170,71 @@ static bool log_mel_spectrogram(
    return true;
 }

+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<whisper_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.empty()) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
 //
 // interface implementation
 //
@@ -2166,7 +2250,8 @@ struct whisper_context * whisper_init(const char * path_model) {

    if (!whisper_model_load(path_model, *ctx)) {
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
-        return NULL;
+        delete ctx;
+        return nullptr;
    }

    ctx->t_load_us = ggml_time_us() - t_start_us;
@@ -2281,8 +2366,38 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
    return res;
 }

+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+    const auto res = tokenize(ctx->vocab, text);
+
+    if (n_max_tokens < (int) res.size()) {
+        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        return -1;
+    }
+
+    for (int i = 0; i < (int) res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+int whisper_lang_max_id() {
+    auto max_id = 0;
+    for (const auto & kv : g_lang) {
+        max_id = std::max(max_id, kv.second.first);
+    }
+
+    return max_id;
+}
+
 int whisper_lang_id(const char * lang) {
    if (!g_lang.count(lang)) {
+        for (const auto & kv : g_lang) {
+            if (kv.second.second == lang) {
+                return kv.second.first;
+            }
+        }
+
        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
        return -1;
    }
@@ -2290,6 +2405,86 @@ int whisper_lang_id(const char * lang) {
    return g_lang.at(lang).first;
 }

+const char * whisper_lang_str(int id) {
+    for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.first.c_str();
+        }
+    }
+
+    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    return nullptr;
+}
+
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
+    const int seek = offset_ms/10;
+
+    if (seek < 0) {
+        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        return -1;
+    }
+
+    if (seek >= ctx->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+        return -2;
+    }
+
+    // run the encoder
+    if (whisper_encode(ctx, seek, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    std::vector<std::pair<float, int>> probs_id;
+    for (const auto & kv : g_lang) {
+        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+    }
+
+    // sort descending
+    {
+        using pair_type = decltype(probs_id)::value_type;
+        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
+
+    // softmax
+    {
+        float sum = 0;
+        for (const auto & kv : probs_id) {
+            sum += exp(kv.first);
+        }
+
+        for (auto & kv : probs_id) {
+            kv.first = exp(kv.first) / sum;
+        }
+    }
+
+    {
+        for (int i = 0; i < (int) probs_id.size(); i++) {
+            if (lang_probs) {
+                lang_probs[probs_id[i].second] = probs_id[i].first;
+            }
+
+            //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+        }
+    }
+
+    return probs_id[0].second;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
    return ctx->mel.n_len;
 }
@@ -2302,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
    return ctx->model.hparams.n_text_ctx;
 }

+int whisper_n_audio_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_ctx;
+}
+
 int whisper_is_multilingual(struct whisper_context * ctx) {
    return ctx->vocab.is_multilingual() ? 1 : 0;
 }
@@ -2338,6 +2537,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
    return ctx->vocab.token_beg;
 }

+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+    return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
 whisper_token whisper_token_translate(void) {
    return whisper_vocab::token_translate;
 }
@@ -2371,7 +2574,10 @@ const char * whisper_print_system_info(void) {
    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
+    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
+    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
+    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
@@ -2569,10 +2775,25 @@ int whisper_full(
    } else {
        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
+            return -2;
        }
    }

+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+    }
+
    if (params.token_timestamps) {
        ctx->t_beg = 0;
        ctx->t_last = 0;
@@ -2605,13 +2826,18 @@ int whisper_full(
        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
    }

-    // overwrite audio_ctx
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -4;
+    }
    ctx->exp_n_audio_ctx = params.audio_ctx;

    // these tokens determine the task that will be performed
    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
    if (whisper_is_multilingual(ctx)) {
-        prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
+        const int lang_id = whisper_lang_id(params.language);
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
        if (params.translate) {
            prompt_init.push_back(whisper_token_translate());
        } else {
@@ -2639,10 +2865,17 @@ int whisper_full(
            }
        }

+        // if only 1 second left, then stop
        if (seek + 100 >= seek_end) {
            break;
        }
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end) {
+            prompt_past.clear();
+        }
+
        if (params.encoder_begin_callback) {
            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@@ -2653,14 +2886,14 @@ int whisper_full(
        // encode audio features starting at offset seek
        if (whisper_encode(ctx, seek, params.n_threads) != 0) {
            fprintf(stderr, "%s: failed to encode\n", __func__);
-            return 7;
+            return -4;
        }

        int n_past = 0;
        prompt.clear();

        // if we have already generated some text, use it as a prompt to condition the next generation
-        if (prompt_past.size() > 0) {
+        if (!prompt_past.empty()) {
            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));

            prompt = { whisper_token_prev(ctx) };
@@ -2686,11 +2919,12 @@ int whisper_full(
        tokens_cur.clear();

        bool failed = false;
+        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?

        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
            if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                fprintf(stderr, "%s: failed to decode\n", __func__);
-                return 8;
+                return -5;
            }

            n_past += prompt.size();
@@ -2711,13 +2945,13 @@ int whisper_full(
                const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));

                // do not allow to go back in time
-                if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
-                    seek_delta > seek_delta_new && result_len < i) {
+                if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                    break;
                }

                seek_delta = seek_delta_new;
                result_len = i + 1;
+                has_ts = true;
            }

            // add it to the context
@@ -2726,11 +2960,14 @@ int whisper_full(

            //{
            //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-            //    printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+            //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
            //}

-            // end of text token
-            if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
+            // end of segment
+            if (token.id == whisper_token_eot(ctx) || // end of text token
+                (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+                (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
+                ) {
                if (result_len == 0) {
                    if (seek + seek_delta + 100 >= seek_end) {
                        result_len = i + 1;
@@ -2765,8 +3002,14 @@ int whisper_full(
            }

            if (failed) {
-                fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
-                seek += 100;
+                // when we fail to sample timestamp token, retry by clearing the past prompt
+                // if it fails again, then we advance the window by 1 second
+                if (!prompt_past.empty()) {
+                    prompt_past.clear();
+                } else {
+                    fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
+                    seek += 100;
+                }
                continue;
            }
@@ -2778,11 +3021,11 @@ int whisper_full(
        }

        // store the text from this iteration
-        if (tokens_cur.size() > 0) {
+        if (!tokens_cur.empty()) {
            int i0 = 0;
            auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));

-            std::string text = "";
+            std::string text;

            for (int i = 0; i < (int) tokens_cur.size(); i++) {
                //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
@@ -2901,10 +3144,9 @@ int whisper_full_parallel(
        // create the ggml memory context
        {
-            struct ggml_init_params params = {
-                .mem_size   = ctxs[i].buf_memory.size(),
-                .mem_buffer = ctxs[i].buf_memory.data(),
-            };
+            struct ggml_init_params params;
+            params.mem_size = ctxs[i].buf_memory.size();
+            params.mem_buffer = ctxs[i].buf_memory.data();

            model.ctx_mem = ggml_init(params);
            if (!model.ctx_mem) {
@@ -2990,7 +3232,7 @@ int whisper_full_parallel(
            results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;

            // make sure that segments are not overlapping
-            if (ctx->result_all.size() > 0) {
+            if (!ctx->result_all.empty()) {
                results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
            }

diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h
index def77d4..e36b761 100644
--- a/examples/whisper/whisper.h
+++ b/examples/whisper/whisper.h
@@ -139,12 +139,45 @@ extern "C" {
    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);

+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens
+    // Returns -1 on failure
+    // TODO: not sure if correct
+    WHISPER_API int whisper_tokenize(
+            struct whisper_context * ctx,
+            const char * text,
+            whisper_token * tokens,
+            int n_max_tokens);
+
+    // Largest language id (i.e. number of available languages - 1)
+    WHISPER_API int whisper_lang_max_id();
+
+    // Return the id of the specified language, returns -1 if not found
+    // Examples:
+    //   "de" -> 2
+    //   "german" -> 2
    WHISPER_API int whisper_lang_id(const char * lang);

+    // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str(int id);
+
+    // Use mel data at offset_ms to try and auto-detect the spoken language
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+    // Returns the top language id or negative on failure
+    // If not null, fills the lang_probs array with the probabilities of all languages
+    // The array must be whisper_lang_max_id() + 1 in size
+    // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+    WHISPER_API int whisper_lang_auto_detect(
+            struct whisper_context * ctx,
+            int offset_ms,
+            int n_threads,
+            float * lang_probs);
+
    WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
    WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
    WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);

    // The probabilities for the next token
@@ -160,6 +193,7 @@ extern "C" {
    WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
    WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
    WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);

    // Task tokens
    WHISPER_API whisper_token whisper_token_translate (void);
@@ -225,6 +259,7 @@ extern "C" {
        const whisper_token * prompt_tokens;
        int prompt_n_tokens;

+        // for auto-detection, set to nullptr, "" or "auto"
        const char * language;

        struct {
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 3e4e962..a217d2d 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -681,34 +681,32 @@ struct ggml_opt_params {
    bool print_forward_graph;
    bool print_backward_graph;

-    union {
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum ggml_linesearch linesearch;
-        } lbfgs;
-    };
+    // ADAM parameters
+    struct {
+        int n_iter;
+
+        float alpha; // learning rate
+        float beta1;
+        float beta2;
+        float eps;   // epsilon for numerical stability
+        float eps_f; // epsilon for convergence test
+        float eps_g; // epsilon for convergence test
+    } adam;
+
+    // LBFGS parameters
+    struct {
+        int m; // number of corrections to approximate the inv. Hessian
+        int n_iter;
+        int max_linesearch;
+
+        float eps;      // convergence tolerance
+        float ftol;     // line search tolerance
+        float wolfe;
+        float min_step;
+        float max_step;
+
+        enum ggml_linesearch linesearch;
+    } lbfgs;
 };

 struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
@@ -726,7 +724,10 @@ enum ggml_opt_result ggml_opt(
 int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_fma(void);
 int ggml_cpu_has_neon(void);
+int ggml_cpu_has_arm_fma(void);
+int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);

diff --git a/src/ggml.c b/src/ggml.c
index b6d528d..7d2f465 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -14,6 +14,12 @@
 #include
 #include

+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+
 #if defined _MSC_VER || defined(__MINGW32__)

 #if !defined(__MINGW32__)
@@ -69,6 +75,10 @@ static int sched_yield (void) {
 typedef void* thread_ret_t;
 #endif

+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
+
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -120,13 +130,35 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
    return x;
 }

+#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP32_TO_FP16(x) (x)
+
 #else

 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
 #include <immintrin.h>
 #endif
+#endif
+
+#ifdef __F16C__
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+    return _cvtsh_ss(h);
+}
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+    return _cvtss_sh(f, 0);
+}
+
+#define GGML_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+
+#else

 // FP16 <-> FP32
 // ref: https://github.com/Maratyszcza/FP16

@@ -135,7 +167,8 @@ static inline float fp32_from_bits(uint32_t w) {
    union {
        uint32_t as_bits;
        float as_value;
-    } fp32 = { w };
+    } fp32;
+    fp32.as_bits = w;
    return fp32.as_value;
 }

@@ -143,7 +176,8 @@ static inline uint32_t fp32_to_bits(float f) {
    union {
        float as_value;
        uint32_t as_bits;
-    } fp32 = { f };
+    } fp32;
+    fp32.as_value = f;
    return fp32.as_bits;
 }

@@ -195,7 +229,13 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
    const uint32_t nonsign = exp_bits + mantissa_bits;
    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ?
UINT16_C(0x7E00) : nonsign); } -#endif + +#define GGML_FP16_TO_FP32(x) ggml_fp16_to_fp32(x) +#define GGML_FP32_TO_FP16(x) ggml_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON // // global data @@ -273,196 +313,429 @@ int64_t ggml_cycles_per_ms(void) { #define CACHE_LINE_SIZE 64 #endif -const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // -// fundamental operations +// simd mappings // -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { - ggml_float sumf = 0.0; -#ifdef __ARM_NEON - // NEON 128-bit - const int n16 = (n & ~15); - - float32x4_t sum0 = vdupq_n_f32(0); - float32x4_t sum1 = vdupq_n_f32(0); - float32x4_t sum2 = vdupq_n_f32(0); - float32x4_t sum3 = vdupq_n_f32(0); - - float32x4_t x0, x1, x2, x3; - float32x4_t y0, y1, y2, y3; - - for (int i = 0; i < n16; i += 16) { - x0 = vld1q_f32(x + i + 0); - x1 = vld1q_f32(x + i + 4); - x2 = vld1q_f32(x + i + 8); - x3 = vld1q_f32(x + i + 12); - - y0 = vld1q_f32(y + i + 0); - y1 = vld1q_f32(y + i + 4); - y2 = vld1q_f32(y + i + 8); - y3 = vld1q_f32(y + i + 12); +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// - sum0 = vfmaq_f32(sum0, x0, y0); - sum1 = vfmaq_f32(sum1, x1, y1); - sum2 = vfmaq_f32(sum2, x2, y2); - sum3 = vfmaq_f32(sum3, x3, y3); - } +#if defined(__ARM_NEON) && 
defined(__ARM_FEATURE_FMA) - // reduce sum0..sum3 to sum0 - sum0 = vaddq_f32(sum0, sum1); - sum2 = vaddq_f32(sum2, sum3); - sum0 = vaddq_f32(sum0, sum2); +#define GGML_SIMD - float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0)); - sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); +// F32 NEON - // leftovers - for (int i = n16; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#if defined(__ARM_FEATURE_QRDMX) + #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#else + #define GGML_F32x4_REDUCE_ONE(x) \ + (vgetq_lane_f32(x, 0) + \ + vgetq_lane_f32(x, 1) + \ + vgetq_lane_f32(x, 2) + \ + vgetq_lane_f32(x, 3)) +#endif +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD GGML_F16x8_LOAD + #define GGML_F16_VEC_STORE GGML_F16x8_STORE + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define 
GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD GGML_F32Cx4_LOAD + #define GGML_F16_VEC_STORE GGML_F32Cx4_STORE + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); +#elif defined(__AVX__) - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); +#define GGML_SIMD - sum0 = _mm256_fmadd_ps(x0, y0, sum0); - sum1 = _mm256_fmadd_ps(x1, y1, sum1); - sum2 = _mm256_fmadd_ps(x2, y2, sum2); - sum3 = _mm256_fmadd_ps(x3, y3, sum3); - } +// F32 AVX - sum0 = _mm256_add_ps(sum0, sum1); - sum2 = _mm256_add_ps(sum2, sum3); - sum0 = _mm256_add_ps(sum0, sum2); +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? 
+ +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead +// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD GGML_F32Cx8_LOAD +#define GGML_F16_VEC_STORE GGML_F32Cx8_STORE +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +// TODO: uncomment this when it works +//#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +// TODO: not tested !! +#define GGML_F32x4 __vector float +#define GGML_F32x4_ZERO (__vector float){0.0f, 0.0f, 0.0f, 0.0f} +#define GGML_F32x4_SET1(x) (__vector float){x, x, x, x} +#define GGML_F32x4_LOAD vec_vsx_ld +#define GGML_F32x4_STORE vec_vsx_st +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vec_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vec_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vec_add(x[8*i], x[8*i+4]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +// TODO: implement here +// ... 
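The GGML_F32Cx8 wrappers defined under the AVX branch above are how F16 tensors run through F32 arithmetic: loads widen with _mm256_cvtph_ps and stores narrow with _mm256_cvtps_ph (immediate 0 selects round-to-nearest). A per-element sketch of the same load/store contract, reusing the GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 helpers this file already uses; the two function names are hypothetical, not part of the patch.

// one element at a time; the F16C intrinsics above do 8 lanes per call
static void f32c_load_ref(const ggml_fp16_t * p, float * v, const int n) {
    for (int i = 0; i < n; ++i) v[i] = GGML_FP16_TO_FP32(p[i]); // widen on load
}

static void f32c_store_ref(ggml_fp16_t * p, const float * v, const int n) {
    for (int i = 0; i < n; ++i) p[i] = GGML_FP32_TO_FP16(v[i]); // narrow on store
}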
- sumf = _mm_cvtss_f32(r1); +#elif defined(__wasm_simd128__) - // leftovers - for (int i = n32; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD GGML_F16x4_LOAD +#define GGML_F16_VEC_STORE GGML_F16x4_STORE +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); +#endif - __m256 x0, x1, x2, x3; - __m256 
y0, y1, y2, y3; +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); +// +// fundamental operations +// - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); - sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); - sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); - sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); - } +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sum0 = _mm256_add_ps(sum0, sum1); - sum2 = _mm256_add_ps(sum2, sum3); - sum0 = _mm256_add_ps(sum0, sum2); +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sumf = _mm_cvtss_f32(r1); +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - // leftovers - for (int i = n32; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__wasm_simd128__) - // WASM 128-bit - const int n16 = (n & ~15); +inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { + ggml_float sumf = 0.0; - v128_t sum0 = wasm_f32x4_splat(0); - v128_t sum1 = wasm_f32x4_splat(0); - v128_t sum2 = wasm_f32x4_splat(0); - v128_t sum3 = wasm_f32x4_splat(0); +#ifdef GGML_SIMD + const int np = (n & ~(GGML_F32_STEP - 1)); - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - for (int i = 0; i < n16; i += 16) { - x0 = wasm_v128_load(x + i + 0); - x1 = wasm_v128_load(x + i + 4); - 
x2 = wasm_v128_load(x + i + 8); - x3 = wasm_v128_load(x + i + 12); + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - y0 = wasm_v128_load(y + i + 0); - y1 = wasm_v128_load(y + i + 4); - y2 = wasm_v128_load(y + i + 8); - y3 = wasm_v128_load(y + i + 12); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); - sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); - sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); - sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } } - sum0 = wasm_f32x4_add(sum0, sum1); - sum2 = wasm_f32x4_add(sum2, sum3); - sum0 = wasm_f32x4_add(sum0, sum2); - - sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); // leftovers - for (int i = n16; i < n; ++i) { + for (int i = np; i < n; ++i) { sumf += x[i]*y[i]; } #else @@ -477,248 +750,87 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0; -#ifdef __ARM_NEON - const int n32 = (n & ~31); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - float16x8_t sum0 = vdupq_n_f16(0); - float16x8_t sum1 = vdupq_n_f16(0); - float16x8_t sum2 = vdupq_n_f16(0); - float16x8_t sum3 = vdupq_n_f16(0); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); - float16x8_t x0, x1, x2, x3; - float16x8_t y0, y1, y2, y3; + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - for (int i = 0; i < n32; i += 32) { - x0 = vld1q_f16(x + i + 0 ); - x1 = vld1q_f16(x + i + 8 ); - x2 = vld1q_f16(x + i + 16); - x3 = vld1q_f16(x + i + 24); + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - y0 = vld1q_f16(y + i + 0 ); - y1 = vld1q_f16(y + i + 8 ); - y2 = vld1q_f16(y + i + 16); - y3 = vld1q_f16(y + i + 24); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); - sum0 = vfmaq_f16(sum0, x0, y0); - sum1 = vfmaq_f16(sum1, x1, y1); - sum2 = vfmaq_f16(sum2, x2, y2); - sum3 = vfmaq_f16(sum3, x3, y3); + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } } // reduce sum0..sum3 to sum0 - sum0 = vaddq_f16(sum0, sum1); - sum2 = vaddq_f16(sum2, sum3); - sum0 = vaddq_f16(sum0, sum2); - - // load sum0 into 2 float32x4_t - float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0)); - float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0)); - - // reduce sum0f32 and sum1f32 to sumf - sum0f32 = vaddq_f32(sum0f32, sum1f32); - - float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32)); - sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); -#else - float32x4_t sum0 = vdupq_n_f32(0); - float32x4_t sum1 = vdupq_n_f32(0); - float32x4_t sum2 = vdupq_n_f32(0); - float32x4_t sum3 = vdupq_n_f32(0); - float32x4_t sum4 = vdupq_n_f32(0); - float32x4_t sum5 = vdupq_n_f32(0); - float32x4_t sum6 = vdupq_n_f32(0); - float32x4_t sum7 = vdupq_n_f32(0); - - float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; - float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; - - for (int i = 0; i < n32; i += 32) { - x0 = 
vcvt_f32_f16(vld1_f16(x + i + 0 )); - x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); - x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); - x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); - x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); - x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); - x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); - x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - - y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); - y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); - y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); - y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); - y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); - y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); - y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); - y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); - - sum0 = vfmaq_f32(sum0, x0, y0); - sum1 = vfmaq_f32(sum1, x1, y1); - sum2 = vfmaq_f32(sum2, x2, y2); - sum3 = vfmaq_f32(sum3, x3, y3); - sum4 = vfmaq_f32(sum4, x4, y4); - sum5 = vfmaq_f32(sum5, x5, y5); - sum6 = vfmaq_f32(sum6, x6, y6); - sum7 = vfmaq_f32(sum7, x7, y7); - } - - // reduce sum0..sum7 to sum0 - sum0 = vaddq_f32(sum0, sum1); - sum2 = vaddq_f32(sum2, sum3); - sum4 = vaddq_f32(sum4, sum5); - sum6 = vaddq_f32(sum6, sum7); - sum0 = vaddq_f32(sum0, sum2); - sum4 = vaddq_f32(sum4, sum6); - sum0 = vaddq_f32(sum0, sum4); - - // reduce sum0 to sumf - float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0)); - sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); -#endif - - // leftovers - for (int i = n32; i < n; ++i) { - sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); - - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - sum0 = _mm256_fmadd_ps(x0, y0, sum0); - sum1 = _mm256_fmadd_ps(x1, y1, sum1); - sum2 = _mm256_fmadd_ps(x2, y2, sum2); - sum3 = _mm256_fmadd_ps(x3, y3, sum3); - } - - const __m256 sum01 = _mm256_add_ps(sum0, sum1); - const __m256 sum23 = _mm256_add_ps(sum2, sum3); - const __m256 sum0123 = _mm256_add_ps(sum01, sum23); - - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); - - sumf = _mm_cvtss_f32(r1); + GGML_F16_VEC_REDUCE(sumf, sum); // leftovers - for (int i = n32; i < n; ++i) { - //GGML_ASSERT(false); - sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + for (int i = np; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); } -#elif defined(__AVX__) - // AVX 256-bit +#elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. 
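Regardless of which SIMD path is taken, ggml_vec_dot_f16 is meant to produce the result of the scalar reference below: widen both F16 operands to F32 and accumulate the products. The POWER9 block that follows can be checked against it; dot_f16_ref is a hypothetical name, not part of the patch.

static float dot_f16_ref(const int n, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        // widen each operand, multiply, accumulate in higher precision
        sum += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
    }
    return sum;
}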
const int n32 = (n & ~31); - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; + vector float sum0 = vec_splats (0.0f); for (int i = 0; i < n32; i += 32) { - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + // Use vec_xl, not vec_ld, because x is sometimes unaligned. + vector unsigned short x0 = vec_xl(i * 2 + 0, x); + vector unsigned short x1 = vec_xl(i * 2 + 16, x); + vector unsigned short x2 = vec_xl(i * 2 + 32, x); + vector unsigned short x3 = vec_xl(i * 2 + 48, x); + + vector unsigned short y0 = vec_xl(i * 2 + 0, y); + vector unsigned short y1 = vec_xl(i * 2 + 16, y); + vector unsigned short y2 = vec_xl(i * 2 + 32, y); + vector unsigned short y3 = vec_xl(i * 2 + 48, y); + + vector float fx0l = vec_extract_fp32_from_shortl(x0); + vector float fx0h = vec_extract_fp32_from_shorth(x0); + vector float fx1l = vec_extract_fp32_from_shortl(x1); + vector float fx1h = vec_extract_fp32_from_shorth(x1); + vector float fx2l = vec_extract_fp32_from_shortl(x2); + vector float fx2h = vec_extract_fp32_from_shorth(x2); + vector float fx3l = vec_extract_fp32_from_shortl(x3); + vector float fx3h = vec_extract_fp32_from_shorth(x3); + + vector float fy0l = vec_extract_fp32_from_shortl(y0); + vector float fy0h = vec_extract_fp32_from_shorth(y0); + vector float fy1l = vec_extract_fp32_from_shortl(y1); + vector float fy1h = vec_extract_fp32_from_shorth(y1); + vector float fy2l = vec_extract_fp32_from_shortl(y2); + vector float fy2h = vec_extract_fp32_from_shorth(y2); + vector float fy3l = vec_extract_fp32_from_shortl(y3); + vector float fy3h = vec_extract_fp32_from_shorth(y3); + + sum0 = vec_add(sum0, vec_mul(fx0l, fy0l)); + sum0 = vec_add(sum0, vec_mul(fx0h, fy0h)); + sum0 = vec_add(sum0, vec_mul(fx1l, fy1l)); + sum0 = vec_add(sum0, vec_mul(fx1h, fy1h)); + sum0 = vec_add(sum0, vec_mul(fx2l, fy2l)); + sum0 = vec_add(sum0, vec_mul(fx2h, fy2h)); + sum0 = vec_add(sum0, vec_mul(fx3l, fy3l)); + sum0 = vec_add(sum0, vec_mul(fx3h, fy3h)); + } + + sumf = vec_extract(sum0, 0) + vec_extract(sum0, 1) + + vec_extract(sum0, 2) + vec_extract(sum0, 3); - sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); - sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); - sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); - sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); - } - - const __m256 sum01 = _mm256_add_ps(sum0, sum1); - const __m256 sum23 = _mm256_add_ps(sum2, sum3); - const __m256 sum0123 = _mm256_add_ps(sum01, sum23); - - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); - - sumf = _mm_cvtss_f32(r1); - - // leftovers for (int i = n32; i < n; ++i) { - //GGML_ASSERT(false); - sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); - } -#elif defined(__wasm_simd128__) - // WASM 128-bit - const int n16 = (n & ~15); - - v128_t sum0 = wasm_f32x4_splat(0.0f); - 
v128_t sum1 = wasm_f32x4_splat(0.0f); - v128_t sum2 = wasm_f32x4_splat(0.0f); - v128_t sum3 = wasm_f32x4_splat(0.0f); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - float tx[16]; - float ty[16]; - - for (int i = 0; i < n16; i += 16) { - for (int k = 0; k < 16; ++k) { - tx[k] = ggml_fp16_to_fp32(x[i + k]); - ty[k] = ggml_fp16_to_fp32(y[i + k]); - } - - x0 = wasm_v128_load(tx + 0); - x1 = wasm_v128_load(tx + 4); - x2 = wasm_v128_load(tx + 8); - x3 = wasm_v128_load(tx + 12); - - y0 = wasm_v128_load(ty + 0); - y1 = wasm_v128_load(ty + 4); - y2 = wasm_v128_load(ty + 8); - y3 = wasm_v128_load(ty + 12); - - sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); - sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); - sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); - sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); - } - - sum0 = wasm_f32x4_add(sum0, sum1); - sum2 = wasm_f32x4_add(sum2, sum3); - sum0 = wasm_f32x4_add(sum0, sum2); - - sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); - - // leftovers - for (int i = n16; i < n; ++i) { - //GGML_ASSERT(false); - sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); } #else for (int i = 0; i < n; ++i) { - sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); } #endif @@ -726,144 +838,26 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t } inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#ifdef __ARM_NEON - // NEON 128-bit - const int n16 = (n & ~15); - - const float32x4_t v4 = vdupq_n_f32(v); - - float32x4_t x0, x1, x2, x3; - float32x4_t y0, y1, y2, y3; - - for (int i = 0; i < n16; i += 16) { - x0 = vld1q_f32(x + i + 0); - x1 = vld1q_f32(x + i + 4); - x2 = vld1q_f32(x + i + 8); - x3 = vld1q_f32(x + i + 12); - - y0 = vld1q_f32(y + i + 0); - y1 = vld1q_f32(y + i + 4); - y2 = vld1q_f32(y + i + 8); - y3 = vld1q_f32(y + i + 12); - - y0 = vfmaq_f32(y0, x0, v4); - y1 = vfmaq_f32(y1, x1, v4); - y2 = vfmaq_f32(y2, x2, v4); - y3 = vfmaq_f32(y3, x3, v4); - - vst1q_f32(y + i + 0, y0); - vst1q_f32(y + i + 4, y1); - vst1q_f32(y + i + 8, y2); - vst1q_f32(y + i + 12, y3); - } - - // leftovers - for (int i = n16; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); - - const __m256 v4 = _mm256_set1_ps(v); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); - - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); - - y0 = _mm256_fmadd_ps(x0, v4, y0); - y1 = _mm256_fmadd_ps(x1, v4, y1); - y2 = _mm256_fmadd_ps(x2, v4, y2); - y3 = _mm256_fmadd_ps(x3, v4, y3); - - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); - _mm256_storeu_ps(y + i + 16, y2); - _mm256_storeu_ps(y + i + 24, y3); - } - - // leftovers - for (int i = n32; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); - const __m256 v4 = _mm256_set1_ps(v); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - __m256 x0, x1, x2, x3; - __m256 y0, 
y1, y2, y3; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); - - y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0); - y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1); - y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2); - y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3); - - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); - _mm256_storeu_ps(y + i + 16, y2); - _mm256_storeu_ps(y + i + 24, y3); - } - - // leftovers - for (int i = n32; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__wasm_simd128__) - // WASM SIMD 128-bit - const int n16 = (n & ~15); - - const v128_t v4 = wasm_f32x4_splat(v); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - for (int i = 0; i < n16; i += 16) { - x0 = wasm_v128_load(x + i + 0); - x1 = wasm_v128_load(x + i + 4); - x2 = wasm_v128_load(x + i + 8); - x3 = wasm_v128_load(x + i + 12); - - y0 = wasm_v128_load(y + i + 0); - y1 = wasm_v128_load(y + i + 4); - y2 = wasm_v128_load(y + i + 8); - y3 = wasm_v128_load(y + i + 12); - - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); - - wasm_v128_store(y + i + 0, y0); - wasm_v128_store(y + i + 4, y1); - wasm_v128_store(y + i + 8, y2); - wasm_v128_store(y + i + 12, y3); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } } // leftovers - for (int i = n16; i < n; ++i) { + for (int i = np; i < n; ++i) { y[i] += x[i]*v; } #else @@ -875,216 +869,125 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { -#ifdef __ARM_NEON - // NEON 128-bit - const int n32 = (n & ~31); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - const float16x8_t v8 = vdupq_n_f16(v); + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - float16x8_t x0, x1, x2, x3; - float16x8_t y0, y1, y2, y3; + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - for (int i = 0; i < n32; i += 32) { - y0 = vld1q_f16(y + i + 0 ); - y1 = vld1q_f16(y + i + 8 ); - y2 = vld1q_f16(y + i + 16); - y3 = vld1q_f16(y + i + 24); - - x0 = vld1q_f16(x + i + 0 ); - x1 = vld1q_f16(x + i + 8 ); - x2 = vld1q_f16(x + i + 16); - x3 = vld1q_f16(x + i + 24); - - y0 = vfmaq_f16(y0, x0, v8); - y1 = vfmaq_f16(y1, x1, v8); - y2 = vfmaq_f16(y2, x2, v8); - y3 = vfmaq_f16(y3, x3, v8); - - vst1q_f16(y + i + 0 , y0); - vst1q_f16(y + i + 8 , y1); - vst1q_f16(y + i + 16, y2); - vst1q_f16(y + i + 24, y3); - } -#else - const float32x4_t v40 = vdupq_n_f32(v); - const float32x4_t v41 = vdupq_n_f32(v); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - 
float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; - float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; - - for (int i = 0; i < n32; i += 32) { - y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); - y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); - y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); - y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); - y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); - y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); - y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); - y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); - - x0 = vcvt_f32_f16(vld1_f16(x + i + 0 )); - x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); - x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); - x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); - x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); - x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); - x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); - x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - - y0 = vfmaq_f32(y0, x0, v40); - y1 = vfmaq_f32(y1, x1, v40); - y2 = vfmaq_f32(y2, x2, v40); - y3 = vfmaq_f32(y3, x3, v40); - y4 = vfmaq_f32(y4, x4, v41); - y5 = vfmaq_f32(y5, x5, v41); - y6 = vfmaq_f32(y6, x6, v41); - y7 = vfmaq_f32(y7, x7, v41); - - vst1_f16(y + i + 0 , vcvt_f16_f32(y0)); - vst1_f16(y + i + 4 , vcvt_f16_f32(y1)); - vst1_f16(y + i + 8 , vcvt_f16_f32(y2)); - vst1_f16(y + i + 12, vcvt_f16_f32(y3)); - vst1_f16(y + i + 16, vcvt_f16_f32(y4)); - vst1_f16(y + i + 20, vcvt_f16_f32(y5)); - vst1_f16(y + i + 24, vcvt_f16_f32(y6)); - vst1_f16(y + i + 28, vcvt_f16_f32(y7)); + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]); + } } -#endif // leftovers - for (int i = n32; i < n; ++i) { + for (int i = np; i < n; ++i) { GGML_ASSERT(false); - y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } -#elif defined(__AVX2__) - // AVX 256-bit +#elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. const int n32 = (n & ~31); - - const __m256 v8 = _mm256_set1_ps(v); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - for (int i = 0; i < n32; i += 32) { - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_fmadd_ps(x0, v8, y0); - y1 = _mm256_fmadd_ps(x1, v8, y1); - y2 = _mm256_fmadd_ps(x2, v8, y2); - y3 = _mm256_fmadd_ps(x3, v8, y3); - - _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); - _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); - _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0)); - _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); + // Use vec_xl, not vec_ld, because x is sometimes unaligned! 
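// Background, as an assumption from the VSX ISA rather than this patch:
// vec_ld maps to lvx, which masks off the low 4 address bits and would
// silently load the wrong 16 bytes from a misaligned pointer, while
// vec_xl emits true unaligned loads. The first vec_xl argument is a byte
// offset - one iteration covers 32 F16 elements = 64 bytes, hence the
// four 16-byte loads at i*2 + 0, 16, 32 and 48.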
+ vector unsigned short x0 = vec_xl(i * 2 + 0, x); + vector unsigned short x1 = vec_xl(i * 2 + 16, x); + vector unsigned short x2 = vec_xl(i * 2 + 32, x); + vector unsigned short x3 = vec_xl(i * 2 + 48, x); + + vector unsigned short y0 = vec_xl(i * 2 + 0, y); + vector unsigned short y1 = vec_xl(i * 2 + 16, y); + vector unsigned short y2 = vec_xl(i * 2 + 32, y); + vector unsigned short y3 = vec_xl(i * 2 + 48, y); + + vector float v4 = vec_splats(v); + + vector float fx0l = vec_extract_fp32_from_shortl(x0); + vector float fx0h = vec_extract_fp32_from_shorth(x0); + vector float fx1l = vec_extract_fp32_from_shortl(x1); + vector float fx1h = vec_extract_fp32_from_shorth(x1); + vector float fx2l = vec_extract_fp32_from_shortl(x2); + vector float fx2h = vec_extract_fp32_from_shorth(x2); + vector float fx3l = vec_extract_fp32_from_shortl(x3); + vector float fx3h = vec_extract_fp32_from_shorth(x3); + + vector float fy0l = vec_extract_fp32_from_shortl(y0); + vector float fy0h = vec_extract_fp32_from_shorth(y0); + vector float fy1l = vec_extract_fp32_from_shortl(y1); + vector float fy1h = vec_extract_fp32_from_shorth(y1); + vector float fy2l = vec_extract_fp32_from_shortl(y2); + vector float fy2h = vec_extract_fp32_from_shorth(y2); + vector float fy3l = vec_extract_fp32_from_shortl(y3); + vector float fy3h = vec_extract_fp32_from_shorth(y3); + + fy0l = vec_madd(fx0l, v4, fy0l); + fy0h = vec_madd(fx0h, v4, fy0h); + fy1l = vec_madd(fx1l, v4, fy1l); + fy1h = vec_madd(fx1h, v4, fy1h); + fy2l = vec_madd(fx2l, v4, fy2l); + fy2h = vec_madd(fx2h, v4, fy2h); + fy3l = vec_madd(fx3l, v4, fy3l); + fy3h = vec_madd(fx3h, v4, fy3h); + + y0 = vec_pack_to_short_fp32(fy0h, fy0l); + y1 = vec_pack_to_short_fp32(fy1h, fy1l); + y2 = vec_pack_to_short_fp32(fy2h, fy2l); + y3 = vec_pack_to_short_fp32(fy3h, fy3l); + + vec_xst(y0, i * 2 + 0, y); + vec_xst(y1, i * 2 + 16, y); + vec_xst(y2, i * 2 + 32, y); + vec_xst(y3, i * 2 + 48, y); } - // leftovers for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); - y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); - - const __m256 v8 = _mm256_set1_ps(v); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0); - y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1); - y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2); - y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3); - - _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); - _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); - _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0)); - _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } - - // leftovers - for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); - y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); +#else + for (int i = 0; i < n; 
++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } -#elif defined(__wasm_simd128__) - // WASM SIMD 128-bit - const int n16 = (n & ~15); - - const v128_t v4 = wasm_f32x4_splat(v); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - float tx[16]; - float ty[16]; - - for (int i = 0; i < n16; i += 16) { - for (int k = 0; k < 16; ++k) { - tx[k] = ggml_fp16_to_fp32(x[i + k]); - ty[k] = ggml_fp16_to_fp32(y[i + k]); - } +#endif +} - x0 = wasm_v128_load(tx + 0); - x1 = wasm_v128_load(tx + 4); - x2 = wasm_v128_load(tx + 8); - x3 = wasm_v128_load(tx + 12); +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); - y0 = wasm_v128_load(ty + 0); - y1 = wasm_v128_load(ty + 4); - y2 = wasm_v128_load(ty + 8); - y3 = wasm_v128_load(ty + 12); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + GGML_F32_VEC ay[GGML_F32_ARR]; - wasm_v128_store(ty + 0, y0); - wasm_v128_store(ty + 4, y1); - wasm_v128_store(ty + 8, y2); - wasm_v128_store(ty + 12, y3); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - for (int k = 0; k < 16; ++k) { - y[i + k] = ggml_fp32_to_fp16(ty[k]); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } } // leftovers - for (int i = n16; i < n; ++i) { - GGML_ASSERT(false); - y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + for (int i = np; i < n; ++i) { + y[i] *= v; } #else + // scalar for (int i = 0; i < n; ++i) { - y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + y[i] *= v; } #endif } -inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } @@ -1093,8 +996,8 @@ inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
x[i] : 0.f; } -const ggml_float GELU_COEF_A = 0.044715; -const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; +static const ggml_float GELU_COEF_A = 0.044715; +static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; inline static float ggml_gelu_f32(float x) { return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); @@ -1111,9 +1014,9 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { uint16_t t; for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]); + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]); + y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); } } #else @@ -1155,7 +1058,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x // data types // -const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { +static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { sizeof(int8_t ), sizeof(int16_t), sizeof(int32_t), @@ -1163,7 +1066,7 @@ const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { sizeof(float ), }; -const char * GGML_OP_LABEL[GGML_OP_COUNT] = { +static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", "DUP", @@ -1203,7 +1106,7 @@ const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", "x", @@ -1256,7 +1159,7 @@ struct ggml_object { char padding[8]; }; -const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -1311,8 +1214,26 @@ struct ggml_state { }; // global state -struct ggml_state g_state; -atomic_int g_state_barrier = 0; +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start() { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end() { + atomic_fetch_sub(&g_state_barrier, 1); +} //////////////////////////////////////////////////////////////////////////////// @@ -1403,7 +1324,7 @@ bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { return tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2];; + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { @@ -1443,32 +1364,45 @@ int ggml_up64(int n) { struct ggml_context * ggml_init(struct ggml_init_params params) { // make this function thread safe - { - int processing = atomic_fetch_add(&g_state_barrier, 1); - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); - processing = atomic_fetch_add(&g_state_barrier, 1); - } - } + ggml_critical_section_start(); 
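// How the barrier above works: ggml_critical_section_start() atomically
// increments g_state_barrier; a previous value > 0 means another thread
// is already inside, so the caller undoes its increment, yields, and
// retries until it enters alone. ggml_critical_section_end() simply
// decrements. Every routine that mutates g_state - ggml_init here and
// ggml_free below - brackets its body with this pair.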
static bool is_first_call = true; + if (is_first_call) { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - ggml_fp16_t ii; - for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = ggml_fp16_to_fp32(ii); - table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f)); - table_exp_f16[i] = ggml_fp32_to_fp16(exp(f)); + // initialize GELU and EXP tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = GGML_FP16_TO_FP32(ii); + table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + g_state = (struct ggml_state) { + /*.contexts =*/ { 0 }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } is_first_call = false; } @@ -1476,14 +1410,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // find non-used context in g_state struct ggml_context * ctx = NULL; - static bool first_time = true; - if (first_time) { - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - g_state.contexts[i].used = false; - } - first_time = false; - } - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { if (!g_state.contexts[i].used) { g_state.contexts[i].used = true; @@ -1497,7 +1423,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { if (ctx == NULL) { GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); - atomic_fetch_sub(&g_state_barrier, 1); + ggml_critical_section_end(); return NULL; } @@ -1515,22 +1441,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_PRINT_DEBUG("%s: context initialized\n", __func__); - atomic_fetch_sub(&g_state_barrier, 1); + ggml_critical_section_end(); return ctx; } void ggml_free(struct ggml_context * ctx) { // make this function thread safe - { - int processing = atomic_fetch_add(&g_state_barrier, 1); - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); - processing = atomic_fetch_add(&g_state_barrier, 1); - } - } + ggml_critical_section_start(); + + bool found = false; for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { if (&g_state.contexts[i].context == ctx) { @@ -1543,15 +1463,16 @@ void ggml_free(struct ggml_context * ctx) { free(ctx->mem_buffer); } - atomic_fetch_sub(&g_state_barrier, 1); - - return; + found = true; + break; } } - GGML_PRINT_DEBUG("%s: context not found\n", __func__); + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } - atomic_fetch_sub(&g_state_barrier, 1); + ggml_critical_section_end(); } size_t ggml_used_mem(const struct ggml_context * ctx) { @@ -1846,7 +1767,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return ggml_fp16_to_fp32(((ggml_fp16_t 
*)(tensor->data))[i]); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } break; case GGML_TYPE_F32: { @@ -1882,7 +1803,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; case GGML_TYPE_F32: { @@ -1916,7 +1837,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } break; case GGML_TYPE_F32: { @@ -1952,7 +1873,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; case GGML_TYPE_F32: { @@ -3132,7 +3053,7 @@ void ggml_set_param( // ggml_compute_forward_dup -void ggml_compute_forward_dup_f16( +static void ggml_compute_forward_dup_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3144,25 +3065,99 @@ void ggml_compute_forward_dup_f16( return; } - //const int ne00 = src0->ne[0]; - //const int ne01 = src0->ne[1]; - //const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - //const size_t nb00 = src0->nb[0]; - //const size_t nb01 = src0->nb[1]; - //const size_t nb02 = src0->nb[2]; - //const size_t nb03 = src0->nb[3]; + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; if (ggml_is_contiguous(src0) && src0->type == dst->type) { memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); return; } - GGML_ASSERT(false); // TODO: implement + if (src0->nb[0] == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + 
i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } } -void ggml_compute_forward_dup_f32( +static void ggml_compute_forward_dup_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3216,7 +3211,7 @@ void ggml_compute_forward_dup_f32( for (int i00 = 0; i00 < ne00; i00++) { const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr); + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); id++; } } @@ -3254,7 +3249,7 @@ void ggml_compute_forward_dup_f32( for (int i00 = 0; i00 < ne00; i00++) { const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr); + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); id++; } } @@ -3266,7 +3261,7 @@ void ggml_compute_forward_dup_f32( } } -void ggml_compute_forward_dup( +static void ggml_compute_forward_dup( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3291,7 +3286,7 @@ void ggml_compute_forward_dup( // ggml_compute_forward_add -void ggml_compute_forward_add_f32( +static void ggml_compute_forward_add_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3344,7 +3339,7 @@ void ggml_compute_forward_add_f32( } } -void ggml_compute_forward_add( +static void ggml_compute_forward_add( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3367,7 +3362,7 @@ void ggml_compute_forward_add( // ggml_compute_forward_sub -void ggml_compute_forward_sub_f32( +static void ggml_compute_forward_sub_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3394,7 +3389,7 @@ void ggml_compute_forward_sub_f32( } } -void ggml_compute_forward_sub( +static void ggml_compute_forward_sub( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3417,7 +3412,7 @@ void ggml_compute_forward_sub( // ggml_compute_forward_mul -void ggml_compute_forward_mul_f32( +static void ggml_compute_forward_mul_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3444,7 +3439,7 @@ void ggml_compute_forward_mul_f32( } } -void ggml_compute_forward_mul( +static void ggml_compute_forward_mul( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3467,7 +3462,7 @@ void ggml_compute_forward_mul( // ggml_compute_forward_div -void ggml_compute_forward_div_f32( +static void ggml_compute_forward_div_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3494,7 +3489,7 @@ void ggml_compute_forward_div_f32( } } -void ggml_compute_forward_div( +static void ggml_compute_forward_div( const struct 
ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -3517,7 +3512,7 @@ void ggml_compute_forward_div( // ggml_compute_forward_sqr -void ggml_compute_forward_sqr_f32( +static void ggml_compute_forward_sqr_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3541,7 +3536,7 @@ void ggml_compute_forward_sqr_f32( } } -void ggml_compute_forward_sqr( +static void ggml_compute_forward_sqr( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3563,7 +3558,7 @@ void ggml_compute_forward_sqr( // ggml_compute_forward_sqrt -void ggml_compute_forward_sqrt_f32( +static void ggml_compute_forward_sqrt_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3587,7 +3582,7 @@ void ggml_compute_forward_sqrt_f32( } } -void ggml_compute_forward_sqrt( +static void ggml_compute_forward_sqrt( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3609,7 +3604,7 @@ void ggml_compute_forward_sqrt( // ggml_compute_forward_sum -void ggml_compute_forward_sum_f32( +static void ggml_compute_forward_sum_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3645,7 +3640,7 @@ void ggml_compute_forward_sum_f32( } } -void ggml_compute_forward_sum( +static void ggml_compute_forward_sum( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3667,7 +3662,7 @@ void ggml_compute_forward_sum( // ggml_compute_forward_mean -void ggml_compute_forward_mean_f32( +static void ggml_compute_forward_mean_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3722,7 +3717,7 @@ void ggml_compute_forward_mean_f32( } } -void ggml_compute_forward_mean( +static void ggml_compute_forward_mean( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3744,7 +3739,7 @@ void ggml_compute_forward_mean( // ggml_compute_forward_repeat -void ggml_compute_forward_repeat_f32( +static void ggml_compute_forward_repeat_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3784,7 +3779,7 @@ void ggml_compute_forward_repeat_f32( } } -void ggml_compute_forward_repeat( +static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3806,7 +3801,7 @@ void ggml_compute_forward_repeat( // ggml_compute_forward_abs -void ggml_compute_forward_abs_f32( +static void ggml_compute_forward_abs_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3830,7 +3825,7 @@ void ggml_compute_forward_abs_f32( } } -void ggml_compute_forward_abs( +static void ggml_compute_forward_abs( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3852,7 +3847,7 @@ void ggml_compute_forward_abs( // ggml_compute_forward_sgn -void ggml_compute_forward_sgn_f32( +static void ggml_compute_forward_sgn_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3876,7 +3871,7 @@ void ggml_compute_forward_sgn_f32( } } -void ggml_compute_forward_sgn( +static void ggml_compute_forward_sgn( const struct 
ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3898,7 +3893,7 @@ void ggml_compute_forward_sgn( // ggml_compute_forward_neg -void ggml_compute_forward_neg_f32( +static void ggml_compute_forward_neg_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3922,7 +3917,7 @@ void ggml_compute_forward_neg_f32( } } -void ggml_compute_forward_neg( +static void ggml_compute_forward_neg( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3944,7 +3939,7 @@ void ggml_compute_forward_neg( // ggml_compute_forward_step -void ggml_compute_forward_step_f32( +static void ggml_compute_forward_step_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3968,7 +3963,7 @@ void ggml_compute_forward_step_f32( } } -void ggml_compute_forward_step( +static void ggml_compute_forward_step( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -3990,7 +3985,7 @@ void ggml_compute_forward_step( // ggml_compute_forward_relu -void ggml_compute_forward_relu_f32( +static void ggml_compute_forward_relu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4014,7 +4009,7 @@ void ggml_compute_forward_relu_f32( } } -void ggml_compute_forward_relu( +static void ggml_compute_forward_relu( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4036,7 +4031,7 @@ void ggml_compute_forward_relu( // ggml_compute_forward_gelu -void ggml_compute_forward_gelu_f32( +static void ggml_compute_forward_gelu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4077,7 +4072,7 @@ void ggml_compute_forward_gelu_f32( } } -void ggml_compute_forward_gelu( +static void ggml_compute_forward_gelu( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4099,7 +4094,7 @@ void ggml_compute_forward_gelu( // ggml_compute_forward_norm -void ggml_compute_forward_norm_f32( +static void ggml_compute_forward_norm_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4159,7 +4154,7 @@ void ggml_compute_forward_norm_f32( } } -void ggml_compute_forward_norm( +static void ggml_compute_forward_norm( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -4181,9 +4176,10 @@ void ggml_compute_forward_norm( // ggml_compute_forward_mul_mat +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // helper function to determine if it is better to use BLAS or not // for large matrices, BLAS is faster -bool ggml_compute_forward_mul_mat_use_blas( +static bool ggml_compute_forward_mul_mat_use_blas( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { @@ -4195,15 +4191,16 @@ bool ggml_compute_forward_mul_mat_use_blas( const int ne1 = dst->ne[1]; // TODO: find the optimal values for these - if (ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) { //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); return true; } return false; } +#endif -void ggml_compute_forward_mul_mat_f32( +static void ggml_compute_forward_mul_mat_f32( 
@@ -4272,7 +4269,6 @@ void ggml_compute_forward_mul_mat_f32(

 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(ggml_is_contiguous(src0));
         GGML_ASSERT(nb10 == sizeof(float));

         if (params->ith != 0) return;
@@ -4447,7 +4443,7 @@ void ggml_compute_forward_mul_mat_f32(
     //}
 }

-void ggml_compute_forward_mul_mat_f16_f32(
+static void ggml_compute_forward_mul_mat_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4536,7 +4532,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
                 int id = 0;
                 for (int i01 = 0; i01 < ne01; ++i01) {
                     for (int i00 = 0; i00 < ne00; ++i00) {
-                        wdata[id++] = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
                     }
                 }
             }
@@ -4564,13 +4560,22 @@ void ggml_compute_forward_mul_mat_f16_f32(
             //    }
             //}

-            // zT = y * xT
             {
+#if 1
+                // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne10,
+                        1.0f,    y, ne00,
+                                 x, ne00,
                         0.0f,    d, ne01);
+#else
+                // zT = (xT * y)T
+                cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+                        ne01, ne11, ne10,
+                        1.0f,    x, ne00,
+                                 y, ne00,
+                        0.0f,    d, ne01);
+#endif
             }
         }
     }
@@ -4590,7 +4595,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
         for (int i12 = 0; i12 < ne12; ++i12) {
             for (int i11 = 0; i11 < ne11; ++i11) {
                 for (int i10 = 0; i10 < ne10; ++i10) {
-                    wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
                 }
             }
         }
@@ -4624,12 +4629,12 @@ void ggml_compute_forward_mul_mat_f16_f32(
         const int ic1 = MIN(ic0 + dc, ne);

         for (int i = ic0; i < ic1; ++i) {
-            ((float *) dst->data)[i] = ggml_fp16_to_fp32(wdata[i]);
+            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
         }

         for (int k = 1; k < nth; k++) {
             for (int i = ic0; i < ic1; ++i) {
-                ((float *) dst->data)[i] += ggml_fp16_to_fp32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
             }
         }
@@ -4742,7 +4747,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }

-void ggml_compute_forward_mul_mat(
+static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4768,7 +4773,7 @@ void ggml_compute_forward_mul_mat(

 // ggml_compute_forward_scale

-void ggml_compute_forward_scale_f32(
+static void ggml_compute_forward_scale_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4803,7 +4808,7 @@ void ggml_compute_forward_scale_f32(
     }
 }

-void ggml_compute_forward_scale(
+static void ggml_compute_forward_scale(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4826,7 +4831,7 @@ void ggml_compute_forward_scale(

 // ggml_compute_forward_cpy

-void ggml_compute_forward_cpy(
+static void ggml_compute_forward_cpy(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4835,7 +4840,7 @@ void ggml_compute_forward_cpy(

 // ggml_compute_forward_reshape

-void ggml_compute_forward_reshape(
+static void ggml_compute_forward_reshape(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4847,7 +4852,7 @@ void ggml_compute_forward_reshape(

 // ggml_compute_forward_view

-void ggml_compute_forward_view(
+static void ggml_compute_forward_view(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4857,7 +4862,7 @@ void ggml_compute_forward_view(

 // ggml_compute_forward_permute

-void ggml_compute_forward_permute(
+static void ggml_compute_forward_permute(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4867,7 +4872,7 @@ void ggml_compute_forward_permute(

 // ggml_compute_forward_transpose

-void ggml_compute_forward_transpose(
+static void ggml_compute_forward_transpose(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4877,7 +4882,7 @@ void ggml_compute_forward_transpose(

 // ggml_compute_forward_get_rows

-void ggml_compute_forward_get_rows_f16(
+static void ggml_compute_forward_get_rows_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4900,12 +4905,12 @@ void ggml_compute_forward_get_rows_f16(

         for (int j = 0; j < nc; ++j) {
             ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = ggml_fp16_to_fp32(v);
+            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }

-void ggml_compute_forward_get_rows_f32(
+static void ggml_compute_forward_get_rows_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4932,7 +4937,7 @@ void ggml_compute_forward_get_rows_f32(
     }
 }

-void ggml_compute_forward_get_rows(
+static void ggml_compute_forward_get_rows(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4958,7 +4963,7 @@ void ggml_compute_forward_get_rows(

 // ggml_compute_forward_diag_mask_inf

-void ggml_compute_forward_diag_mask_inf_f32(
+static void ggml_compute_forward_diag_mask_inf_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4994,7 +4999,7 @@ void ggml_compute_forward_diag_mask_inf_f32(
     }
 }

-void ggml_compute_forward_diag_mask_inf(
+static void ggml_compute_forward_diag_mask_inf(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5017,7 +5022,7 @@ void ggml_compute_forward_diag_mask_inf(

 // ggml_compute_forward_soft_max

-void ggml_compute_forward_soft_max_f32(
+static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
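Note on the cblas_sgemm change inside ggml_compute_forward_mul_mat_f16_f32 above: in a row-major GEMM the lda/ldb arguments are the row strides of the stored operands, i.e. the row length ne00 of the converted x and y buffers, not the reduction length ne10. ggml requires ne10 == ne00 for a matrix multiplication, so the change is behavior-preserving, but ne00 states the intent correctly; the #if 1/#else block also records an equivalent column-major formulation, zT = (xT * y)T, for reference. A standalone illustration of the row-major call shape (not ggml code; link against OpenBLAS with -lopenblas, or use the Accelerate framework on macOS):

    #include <stdio.h>
    #include <cblas.h>

    int main(void) {
        enum { M = 2, N = 2, K = 3 };
        // x: M rows of length K; y: N rows of length K (row-major storage).
        const float x[M*K] = { 1, 2, 3,
                               4, 5, 6 };
        const float y[N*K] = { 1, 0, 1,
                               0, 1, 0 };
        float d[N*M] = { 0 };

        // zT = y * xT : (N x K) * (K x M) -> N x M.
        // The leading dimensions are the stored row lengths: K for x and y
        // (ne00 in ggml terms), M for the output (ne01).
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    N, M, K,
                    1.0f, y, K,
                          x, K,
                    0.0f, d, M);

        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                printf("%4.0f", d[i*M + j]);   // prints: 4 10 / 2 5
            }
            printf("\n");
        }
        return 0;
    }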
@@ -5066,9 +5071,9 @@ void ggml_compute_forward_soft_max_f32(
                 p[i] = 0.0;
             } else {
                 //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
-                ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&ss, &s, sizeof(ss));
-                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                 sum += val;
                 p[i] = val;
             }
@@ -5088,7 +5093,7 @@ void ggml_compute_forward_soft_max_f32(
     }
 }

-void ggml_compute_forward_soft_max(
+static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -5110,7 +5115,7 @@ void ggml_compute_forward_soft_max(

 // ggml_compute_forward_rope

-void ggml_compute_forward_rope_f32(
+static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5167,7 +5172,7 @@ void ggml_compute_forward_rope_f32(
     }
 }

-void ggml_compute_forward_rope(
+static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5190,7 +5195,7 @@ void ggml_compute_forward_rope(

 // ggml_compute_forward_conv_1d_1s

-void ggml_compute_forward_conv_1d_1s_f16_f32(
+static void ggml_compute_forward_conv_1d_1s_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5272,7 +5277,7 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
             const float * const src = (float *)((char *) src1->data + i11*nb11);
             ggml_fp16_t * dst_data = wdata;
             for (int i10 = 0; i10 < ne10; i10++) {
-                dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+                dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
             }
         }
     }
@@ -5310,7 +5315,7 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
     }
 }

-void ggml_compute_forward_conv_1d_1s_f32(
+static void ggml_compute_forward_conv_1d_1s_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5430,7 +5435,7 @@ void ggml_compute_forward_conv_1d_1s_f32(
     }
 }

-void ggml_compute_forward_conv_1d_1s(
+static void ggml_compute_forward_conv_1d_1s(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5456,7 +5461,7 @@ void ggml_compute_forward_conv_1d_1s(

 // ggml_compute_forward_conv_1d_2s

-void ggml_compute_forward_conv_1d_2s_f16_f32(
+static void ggml_compute_forward_conv_1d_2s_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5538,7 +5543,7 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
             const float * const src = (float *)((char *) src1->data + i11*nb11);
             ggml_fp16_t * dst_data = wdata;
             for (int i10 = 0; i10 < ne10; i10++) {
-                dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+                dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
             }
         }
     }
@@ -5576,7 +5581,7 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
     }
 }

-void ggml_compute_forward_conv_1d_2s_f32(
+static void ggml_compute_forward_conv_1d_2s_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5696,7 +5701,7 @@ void ggml_compute_forward_conv_1d_2s_f32(
     }
 }

-void ggml_compute_forward_conv_1d_2s(
+static void ggml_compute_forward_conv_1d_2s(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5722,7 +5727,7 @@ void ggml_compute_forward_conv_1d_2s(

 // ggml_compute_forward_flash_attn

-void ggml_compute_forward_flash_attn_f32(
+static void ggml_compute_forward_flash_attn_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -5875,9 +5880,9 @@ void ggml_compute_forward_flash_attn_f32(
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
                     memcpy(&ss, &s, sizeof(ss));
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }
@@ -5903,7 +5908,7 @@ void ggml_compute_forward_flash_attn_f32(
     }
 }

-void ggml_compute_forward_flash_attn_f16(
+static void ggml_compute_forward_flash_attn_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -6056,9 +6061,9 @@ void ggml_compute_forward_flash_attn_f16(
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
                     memcpy(&ss, &s, sizeof(ss));
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }
@@ -6073,7 +6078,7 @@ void ggml_compute_forward_flash_attn_f16(
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);

         for (int i = 0; i < M; i++) {
-            S16[i] = ggml_fp32_to_fp16(S[i]);
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
         }

         for (int ic = 0; ic < nev1; ++ic) {
@@ -6090,7 +6095,7 @@ void ggml_compute_forward_flash_attn_f16(
     }
 }

-void ggml_compute_forward_flash_attn(
+static void ggml_compute_forward_flash_attn(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -6118,7 +6123,7 @@ void ggml_compute_forward_flash_attn(

 // ggml_compute_forward_flash_ff

-void ggml_compute_forward_flash_ff_f16(
+static void ggml_compute_forward_flash_ff_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * a,  // F16
         const struct ggml_tensor * b0, // F16 fc_w
@@ -6271,7 +6276,7 @@ void ggml_compute_forward_flash_ff_f16(
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);

         for (int i = 0; i < M; i++) {
-            S16[i] = ggml_fp32_to_fp16(S[i]);
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
         }

         ggml_vec_gelu_f16(neb01, S16, S16);
@@ -6298,7 +6303,7 @@ void ggml_compute_forward_flash_ff_f16(
     }
 }

-void ggml_compute_forward_flash_ff(
+static void ggml_compute_forward_flash_ff(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * a,
         const struct ggml_tensor * b0,
@@ -6327,7 +6332,7 @@ void ggml_compute_forward_flash_ff(

 /////////////////////////////////

-void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     assert(params);

     switch (tensor->op) {
@@ -6470,12 +6475,12 @@ void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tenso
             {
                 GGML_ASSERT(false);
             } break;
-    };
+    }
 }

 ////////////////////////////////////////////////////////////////////////////////

-void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
     struct ggml_tensor * src0 = tensor->src0;
     struct ggml_tensor * src1 = tensor->src1;
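Note on the ggml_fp32_to_fp16/ggml_fp16_to_fp32 to GGML_FP32_TO_FP16/GGML_FP16_TO_FP32 renames throughout the hunks above: the upper-case forms are conversion macros, which lets the conversion lower to a single hardware instruction where available (F16C on x86, native FP16 arithmetic on ARM) instead of a function call inside these hot loops. The table_exp_f16 lookups they feed replace exp() with one conversion plus one load: the table holds one entry per possible fp16 bit pattern, and the source memcpy's the fp16 value into an integer index (ss) because ggml_fp16_t may be a genuine floating type. A self-contained sketch of the idea (my_* names are illustrative, not ggml's; assumes a compiler with the _Float16 extension, whereas ggml itself carries portable bit-twiddling fallbacks):

    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    static float my_fp16_to_fp32(uint16_t bits) {
        _Float16 h;
        memcpy(&h, &bits, sizeof(h));
        return (float) h;
    }

    static uint16_t my_fp32_to_fp16(float f) {
        const _Float16 h = (_Float16) f;
        uint16_t bits;
        memcpy(&bits, &h, sizeof(bits));
        return bits;
    }

    // One entry per fp16 bit pattern: 2^16 * 2 bytes = 128 KiB.
    static uint16_t my_table_exp_f16[1 << 16];

    // Fill the table once at startup.
    static void my_init_exp_table(void) {
        for (uint32_t i = 0; i < (1 << 16); ++i) {
            my_table_exp_f16[i] = my_fp32_to_fp16(expf(my_fp16_to_fp32((uint16_t) i)));
        }
    }

    // exp(x) at fp16 precision: convert, use the bits as an index, load.
    static float my_exp_f16(float x) {
        return my_fp16_to_fp32(my_table_exp_f16[my_fp32_to_fp16(x)]);
    }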
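A small cleanup also lands here and twice in the hunks below: the `};` closing several switch statements becomes `}`. The extra semicolon is a stray null statement after the switch block; it is legal C but flagged by stricter warning settings (e.g. Clang's -Wextra-semi-stmt), so dropping it keeps pedantic builds quiet.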
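Further below, pthread_t/pthread_create/pthread_join are wrapped behind ggml_thread_t/ggml_thread_create/ggml_thread_join, and the unused per-thread helper ggml_graph_compute_one is removed. The patch only defines the pthread-backed variant, but the indirection leaves a single seam where another threading backend could be mapped in later. A hypothetical sketch of what a Win32 mapping could look like (this is not part of the patch; CreateThread and WaitForSingleObject are the real Win32 APIs, while the my_* names are made up for illustration):

    #if defined(_WIN32)
    #include <windows.h>

    typedef HANDLE my_thread_t;

    // Mirror pthread_create's 0-on-success convention; the attribute
    // argument is accepted and ignored.
    static int my_thread_create(my_thread_t * out, void * attr_unused,
                                DWORD (WINAPI * fn)(LPVOID), void * arg) {
        (void) attr_unused;
        *out = CreateThread(NULL, 0, fn, arg, 0, NULL);
        return *out == NULL;
    }

    static int my_thread_join(my_thread_t thrd, void * retval_unused) {
        (void) retval_unused;
        WaitForSingleObject(thrd, INFINITE);
        CloseHandle(thrd);
        return 0;
    }
    #endif

The new probes at the end of the file (ggml_cpu_has_fma, ggml_cpu_has_arm_fma, ggml_cpu_has_f16c) follow the existing ggml_cpu_has_* pattern: they report what the compiler was allowed to emit (e.g. __FMA__ is only defined when building with -mfma), i.e. build-time capability rather than runtime CPU detection.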
@@ -6716,10 +6721,10 @@ void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tenso
             {
                 GGML_ASSERT(false);
             } break;
-    };
+    }
 }

-void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
         // it can also happen during forward pass, if the user performs computations with constants
@@ -6770,7 +6775,7 @@ void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node)
     }
 }

-void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
     if (!expand) {
         cgraph->n_nodes = 0;
         cgraph->n_leafs = 0;
@@ -6881,6 +6886,11 @@ typedef int ggml_lock_t;

 #define GGML_LOCK_INITIALIZER 0

+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
 #else

 //typedef pthread_spinlock_t ggml_lock_t;
@@ -6899,6 +6909,11 @@ typedef int ggml_lock_t;

 #define GGML_LOCK_INITIALIZER 0

+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
 #endif

 struct ggml_compute_state_shared {
@@ -6913,7 +6928,7 @@ struct ggml_compute_state_shared {
 };

 struct ggml_compute_state {
-    pthread_t thrd;
+    ggml_thread_t thrd;

     struct ggml_compute_params params;
     struct ggml_tensor * node;
@@ -6921,16 +6936,7 @@ struct ggml_compute_state {
     struct ggml_compute_state_shared * shared;
 };

-// function used by each compute thread
-void * ggml_graph_compute_one(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-
-    ggml_compute_forward(&state->params, state->node);
-
-    return NULL;
-}
-
-thread_ret_t ggml_graph_compute_thread(void * data) {
+static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;

     const int n_threads = state->shared->n_threads;
@@ -7010,7 +7016,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 .node   = NULL,
                 .shared = &state_shared,
             };
-            int rc = pthread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
             assert(rc == 0);
             UNUSED(rc);
         }
@@ -7185,7 +7191,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         assert(false);
                     } break;
-            };
+            }
         }

         if (cgraph->work != NULL && work_size > cgraph->work_size) {
@@ -7354,7 +7360,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         atomic_store(&state_shared.has_work, true);

         for (int j = 0; j < n_threads - 1; j++) {
-            int rc = pthread_join(workers[j].thrd, NULL);
+            int rc = ggml_thread_join(workers[j].thrd, NULL);
             assert(rc == 0);
             UNUSED(rc);
         }
@@ -7432,7 +7438,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 }

 // check if node is part of the graph
-bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     if (cgraph == NULL) {
         return true;
     }
@@ -7446,7 +7452,7 @@ void ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor

     return false;
 }

-struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * parent = cgraph->nodes[i];
@@ -7575,7 +7581,7 @@ label=\"CONST %d [%d, %d]\"; ]\n",

 ////////////////////////////////////////////////////////////////////////////////

-void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7586,7 +7592,7 @@ void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float *
     }
 }

-void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7597,7 +7603,7 @@ void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
     }
 }

-void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7614,7 +7620,7 @@ void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
 // ref: https://arxiv.org/pdf/1412.6980.pdf
 //

-enum ggml_opt_result ggml_opt_adam(
+static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -7907,7 +7913,7 @@ static enum ggml_opt_result linesearch_backtracking(
     return GGML_LINESEARCH_FAIL;
 }

-enum ggml_opt_result ggml_opt_lbfgs(
+static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -8270,6 +8276,14 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }

+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_neon(void) {
 #if defined(__ARM_NEON)
     return 1;
@@ -8278,6 +8292,22 @@ int ggml_cpu_has_neon(void) {
 #endif
 }

+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fp16_va(void) {
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     return 1;