From a6acb3318aaece6d05dfeb7f1615c70bf7dab804 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 Feb 2023 20:59:36 +0200 Subject: [PATCH] sync : latest whisper.cpp (scratch buffers in ggml) --- examples/whisper/CMakeLists.txt | 2 +- examples/whisper/common.cpp | 162 +++++ examples/whisper/common.h | 40 ++ examples/whisper/main.cpp | 213 +++---- examples/whisper/whisper.cpp | 1002 ++++++++++++++++++++++--------- examples/whisper/whisper.h | 24 +- include/ggml/ggml.h | 9 + src/ggml.c | 151 +++-- 8 files changed, 1146 insertions(+), 457 deletions(-) create mode 100644 examples/whisper/common.cpp create mode 100644 examples/whisper/common.h diff --git a/examples/whisper/CMakeLists.txt b/examples/whisper/CMakeLists.txt index 55dd1b4..c8fa83a 100644 --- a/examples/whisper/CMakeLists.txt +++ b/examples/whisper/CMakeLists.txt @@ -10,6 +10,6 @@ target_link_libraries(whisper-cpp PRIVATE ) set(TEST_TARGET whisper) -add_executable(${TEST_TARGET} main.cpp) +add_executable(${TEST_TARGET} main.cpp common.cpp) target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp) target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/examples/whisper/common.cpp b/examples/whisper/common.cpp new file mode 100644 index 0000000..194ef0e --- /dev/null +++ b/examples/whisper/common.cpp @@ -0,0 +1,162 @@ +#include "common.h" + +// third-party utilities +// use your favorite implementations +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" + +#include +#include + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +std::string trim(const std::string & s) { + std::regex e("^\\s+|\\s+$"); + return std::regex_replace(s, e, ""); +} + +std::string replace(const std::string & s, const std::string & from, const std::string & to) { + std::string result = s; + size_t pos = 0; + while ((pos = result.find(from, pos)) != std::string::npos) { + result.replace(pos, from.length(), to); + pos += to.length(); + } + return result; +} + +bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { + drwav wav; + std::vector wav_data; // used for pipe input from stdin + + if (fname == "-") { + { + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return false; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } + else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); + return false; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); + return false; + } + + if (stereo && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str()); + return false; + } + + if (wav.sampleRate != COMMON_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000); + return false; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str()); + return false; + } + + const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (stereo) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + + return true; +} + +void high_pass_filter(std::vector & data, float cutoff, float sample_rate) { + const float rc = 1.0f / (2.0f * M_PI * cutoff); + const float dt = 1.0f / sample_rate; + const float alpha = dt / (rc + dt); + + float y = data[0]; + + for (size_t i = 1; i < data.size(); i++) { + y = alpha * (y + data[i] - data[i - 1]); + data[i] = y; + } +} + +bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { + const int n_samples = pcmf32.size(); + const int n_samples_last = (sample_rate * last_ms) / 1000; + + if (n_samples_last >= n_samples) { + // not enough samples - assume no speech + return false; + } + + if (freq_thold > 0.0f) { + high_pass_filter(pcmf32, freq_thold, sample_rate); + } + + float energy_all = 0.0f; + float energy_last = 0.0f; + + for (int i = 0; i < n_samples; i++) { + energy_all += fabsf(pcmf32[i]); + + if (i >= n_samples - n_samples_last) { + energy_last += fabsf(pcmf32[i]); + } + } + + energy_all /= n_samples; + energy_last /= n_samples_last; + + if (verbose) { + fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold); + } + + if (energy_last > vad_thold*energy_all) { + return false; + } + + return true; +} diff --git a/examples/whisper/common.h b/examples/whisper/common.h new file mode 100644 index 0000000..04dd7cb --- /dev/null +++ b/examples/whisper/common.h @@ -0,0 +1,40 @@ +#pragma once + +// needs to match WHISPER_SAMPLE_RATE +#define COMMON_SAMPLE_RATE 16000 + +#include +#include + +std::string trim(const std::string & s); + +std::string replace( + const std::string & s, + const std::string & from, + const std::string & to); + +// Read WAV audio file and store the PCM data into pcmf32 +// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE +// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM +bool read_wav( + const std::string & fname, + std::vector & pcmf32, + std::vector> & pcmf32s, + bool stereo); + +// Apply a high-pass frequency filter to PCM audio +// Suppresses frequencies below cutoff Hz +void high_pass_filter( + std::vector & data, + float cutoff, + float sample_rate); + +// Basic voice activity detection (VAD) using audio energy adaptive threshold +bool vad_simple( + std::vector & pcmf32, + int sample_rate, + int last_ms, + float vad_thold, + float freq_thold, + bool verbose); + diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 65b06ca..5bd7e42 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -1,9 +1,6 @@ -#include "whisper.h" +#include "common.h" -// third-party utilities -// use your favorite implementations -#define DR_WAV_IMPLEMENTATION -#include "dr_wav.h" 
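// -----------------------------------------------------------------------------
// [editor's note, not part of the patch] A minimal sketch of how the helpers
// moved into common.h are meant to be used together. The WAV path and the
// last_ms/vad_thold/freq_thold values are illustrative assumptions; read_wav(),
// vad_simple() and COMMON_SAMPLE_RATE are exactly the API added above.

#include "common.h"

#include <cstdio>
#include <vector>

int main() {
    std::vector<float>              pcmf32;  // mono PCM at COMMON_SAMPLE_RATE (16 kHz)
    std::vector<std::vector<float>> pcmf32s; // per-channel PCM, filled only when stereo == true

    if (!read_wav("input.wav", pcmf32, pcmf32s, /*stereo=*/false)) {
        return 1;
    }

    // vad_simple() optionally applies the one-pole high_pass_filter() (when
    // freq_thold > 0), then compares the mean energy of the last `last_ms`
    // milliseconds against the mean energy of the whole buffer. It returns
    // true when the tail is relatively quiet, i.e. speech appears to have
    // stopped.
    const bool speech_ended = vad_simple(pcmf32, COMMON_SAMPLE_RATE,
                                         /*last_ms   =*/ 1000,
                                         /*vad_thold =*/ 0.6f,
                                         /*freq_thold=*/ 100.0f,
                                         /*verbose   =*/ false);

    fprintf(stdout, "samples: %zu, speech ended: %s\n",
            pcmf32.size(), speech_ended ? "yes" : "no");

    return 0;
}
// -----------------------------------------------------------------------------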
+#include "whisper.h" #include #include @@ -53,22 +50,24 @@ void replace_all(std::string & s, const std::string & search, const std::string // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t n_processors = 1; - int32_t offset_t_ms = 0; - int32_t offset_n = 0; - int32_t duration_ms = 0; + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; int32_t max_context = -1; - int32_t max_len = 0; - int32_t best_of = 5; + int32_t max_len = 0; + int32_t best_of = 5; int32_t beam_size = -1; - float word_thold = 0.01f; - float entropy_thold = 2.4f; - float logprob_thold = -1.0f; + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; bool speed_up = false; bool translate = false; bool diarize = false; + bool split_on_word = false; + bool no_fallback = false; bool output_txt = false; bool output_vtt = false; bool output_srt = false; @@ -84,6 +83,7 @@ struct whisper_params { std::string model = "models/ggml-base.en.bin"; std::vector fname_inp = {}; + std::vector fname_out = {}; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -91,7 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params); bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; - + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + if (arg[0] != '-') { params.fname_inp.push_back(arg); continue; @@ -116,11 +121,14 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } @@ -144,35 +152,38 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment 
index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? 
"true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } @@ -343,9 +354,6 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - if (text[0] == ' ') { - text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. - } const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); @@ -514,90 +522,14 @@ int main(int argc, char ** argv) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; + const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; - std::vector pcmf32; // mono-channel F32 PCM + std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM - // WAV input - { - drwav wav; - std::vector wav_data; // used for pipe input from stdin - - if (fname_inp == "-") { - { - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - wav_data.insert(wav_data.end(), buf, buf + n); - } - } - - if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { - fprintf(stderr, "error: failed to open WAV file from stdin\n"); - return 4; - } - - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); - } - else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { - fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); - return 5; - } - - if (wav.channels != 1 && wav.channels != 2) { - fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); - return 6; - } - - if (params.diarize && wav.channels != 2 && params.no_timestamps == false) { - fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str()); - return 6; - } - - if (wav.sampleRate != WHISPER_SAMPLE_RATE) { - fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); - return 8; - } - - if (wav.bitsPerSample != 16) { - fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); - return 9; - } - - const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); - - std::vector pcm16; - pcm16.resize(n*wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); - drwav_uninit(&wav); - - // convert to mono, float - pcmf32.resize(n); - if (wav.channels == 1) { - for (uint64_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[i])/32768.0f; - } - } else { - for (uint64_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; - } - } - - if (params.diarize) { - // convert to stereo, float - pcmf32s.resize(2); - - pcmf32s[0].resize(n); - pcmf32s[1].resize(n); - for (uint64_t i = 0; i < n; i++) { - pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; - pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; - } - } + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; } // print system information @@ -646,18 +578,20 @@ int main(int argc, char ** argv) { wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.thold_pt = params.word_thold; - wparams.entropy_thold = params.entropy_thold; - wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.temperature_inc = -1; - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.temperature_inc = params.no_fallback ? 
0.0f : wparams.temperature_inc; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; whisper_print_user_data user_data = { ¶ms, &pcmf32s }; @@ -692,34 +626,33 @@ int main(int argc, char ** argv) { // output to text file if (params.output_txt) { - const auto fname_txt = fname_inp + ".txt"; + const auto fname_txt = fname_out + ".txt"; output_txt(ctx, fname_txt.c_str()); } // output to VTT file if (params.output_vtt) { - const auto fname_vtt = fname_inp + ".vtt"; + const auto fname_vtt = fname_out + ".vtt"; output_vtt(ctx, fname_vtt.c_str()); } // output to SRT file if (params.output_srt) { - const auto fname_srt = fname_inp + ".srt"; + const auto fname_srt = fname_out + ".srt"; output_srt(ctx, fname_srt.c_str(), params); } // output to WTS file if (params.output_wts) { - const auto fname_wts = fname_inp + ".wts"; + const auto fname_wts = fname_out + ".wts"; output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); } - // output to CSV file + // output to CSV file if (params.output_csv) { - const auto fname_csv = fname_inp + ".csv"; + const auto fname_csv = fname_out + ".csv"; output_csv(ctx, fname_csv.c_str()); } - } } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 05bf58e..331d408 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -17,6 +17,68 @@ #include #include +#if defined(GGML_BIG_ENDIAN) +#include + +template +static T byteswap(T value) { + return std::byteswap(value); +} + +template<> +float byteswap(float value) { + return std::bit_cast(byteswap(std::bit_cast(value))); +} + +template +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(tensor); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + #define WHISPER_ASSERT(x) \ do { \ if (!(x)) { \ @@ -41,6 +103,9 @@ //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 +#define WHISPER_USE_SCRATCH +#define WHISPER_MAX_SCRATCH_BUFFERS 16 + // available whisper models enum e_model { MODEL_UNKNOWN, @@ -155,6 +220,38 @@ static const std::map> g_lang = { static const size_t MB = 1024*1024; +static const std::map MEM_REQ_SCRATCH0 = { + { MODEL_TINY, 12ull*MB }, + { MODEL_BASE, 15ull*MB }, + { MODEL_SMALL, 23ull*MB }, + { MODEL_MEDIUM, 31ull*MB }, + { MODEL_LARGE, 38ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH1 = { + { MODEL_TINY, 18ull*MB }, + { MODEL_BASE, 24ull*MB }, + { MODEL_SMALL, 36ull*MB }, + { MODEL_MEDIUM, 48ull*MB }, + { MODEL_LARGE, 60ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH2 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { 
MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH3 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + static const std::map MEM_REQ_MODEL = { { MODEL_TINY, 74ull*MB }, { MODEL_BASE, 142ull*MB }, @@ -180,35 +277,19 @@ static const std::map MEM_REQ_KV_CROSS = { }; static const std::map MEM_REQ_ENCODE = { - { MODEL_TINY, 80ull*MB }, - { MODEL_BASE, 128ull*MB }, - { MODEL_SMALL, 300ull*MB }, - { MODEL_MEDIUM, 680ull*MB }, - { MODEL_LARGE, 1100ull*MB }, -}; - -static const std::map MEM_REQ_ENCODE_LAYER = { - { MODEL_TINY, 104ull*MB }, - { MODEL_BASE, 138ull*MB }, - { MODEL_SMALL, 208ull*MB }, - { MODEL_MEDIUM, 280ull*MB }, - { MODEL_LARGE, 354ull*MB }, + { MODEL_TINY, 6ull*MB }, + { MODEL_BASE, 8ull*MB }, + { MODEL_SMALL, 13ull*MB }, + { MODEL_MEDIUM, 22ull*MB }, + { MODEL_LARGE, 33ull*MB }, }; static const std::map MEM_REQ_DECODE = { - { MODEL_TINY, 200ull*MB }, - { MODEL_BASE, 202ull*MB }, - { MODEL_SMALL, 204ull*MB }, - { MODEL_MEDIUM, 206ull*MB }, - { MODEL_LARGE, 208ull*MB }, -}; - -static const std::map MEM_REQ_DECODE_LAYER = { - { MODEL_TINY, 32ull*MB }, - { MODEL_BASE, 44ull*MB }, - { MODEL_SMALL, 64ull*MB }, - { MODEL_MEDIUM, 84ull*MB }, - { MODEL_LARGE, 110ull*MB }, + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 5ull*MB }, + { MODEL_SMALL, 10ull*MB }, + { MODEL_MEDIUM, 18ull*MB }, + { MODEL_LARGE, 27ull*MB }, }; struct whisper_mel { @@ -474,6 +555,12 @@ struct whisper_context { int64_t t_decode_us = 0; int64_t t_start_us = 0; + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + ggml_type wtype; // weight type (FP32 or FP16) whisper_mel mel; @@ -489,7 +576,10 @@ struct whisper_context { // memory buffers used by encode / decode contexts std::vector buf_compute; - std::vector buf_compute_layer; + std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -502,6 +592,8 @@ struct whisper_context { mutable std::mt19937 rng; // used for sampling at t > 0.0 + int lang_id; + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -510,11 +602,43 @@ struct whisper_context { // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx; // 0 - use default + + void use_buf(struct ggml_context * ctx, int i) { +#if defined(WHISPER_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; + (void) ctx; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(WHISPER_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } }; template static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); } static bool kv_cache_init( @@ -675,10 +799,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { // this is the 
total memory required to run the inference const size_t mem_required = - scale*MEM_REQ_MODEL.at (model.type) + - scale*MEM_REQ_KV_CROSS.at (model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) + - scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)); + MEM_REQ_SCRATCH0.at (model.type) + + MEM_REQ_SCRATCH1.at (model.type) + + MEM_REQ_SCRATCH2.at (model.type) + + MEM_REQ_SCRATCH3.at (model.type) + + scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_KV_CROSS.at(model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); // this is the memory required by one decoder const size_t mem_required_decoder = @@ -714,8 +841,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } - wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + + wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); + wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); + wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); + wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); } // load mel filters @@ -727,6 +858,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con filters.data.resize(filters.n_mel * filters.n_fft); loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); } // load vocab @@ -1190,6 +1322,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? 
"float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); @@ -1246,6 +1379,8 @@ static bool whisper_encode( struct ggml_context * ctx0 = ggml_init(params); + wctx.use_buf(ctx0, 0); + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); { @@ -1266,6 +1401,8 @@ static bool whisper_encode( // convolution + gelu { + wctx.use_buf(ctx0, 1); + cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1275,6 +1412,8 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); + wctx.use_buf(ctx0, 0); + cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1285,6 +1424,8 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } + wctx.use_buf(ctx0, 3); + // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) //static int iter = -1; @@ -1305,6 +1446,7 @@ static bool whisper_encode( struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + // =================================================================== // original: @@ -1315,153 +1457,158 @@ static bool whisper_encode( for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_encoder[il]; - // create separate context for each layer to reduce memory usage - - struct ggml_init_params paramsL; - paramsL.mem_size = wctx.buf_compute_layer.size(); - paramsL.mem_buffer = wctx.buf_compute_layer.data(); - - struct ggml_context * ctxL = ggml_init(paramsL); - // norm { - cur = ggml_norm(ctxL, inpL); + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_q_b, Qcur), Qcur); - //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_v_b, Vcur), Vcur); // ------ + wctx.use_buf(ctx0, 0); + #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + 
ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head) + ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) ); - struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); #else struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ_scaled = - ggml_scale(ctxL, + ggml_scale(ctx0, KQ, - ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) ); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); //struct ggml_tensor * V_trans = - // ggml_permute(ctxL, - // ggml_cpy(ctxL, + // ggml_permute(ctx0, + // ggml_cpy(ctx0, // Vcur, - // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); - //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head) + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) ); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); #endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + wctx.use_buf(ctx0, 1); - cur = ggml_cpy(ctxL, + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); } // projection { - cur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } + wctx.use_buf(ctx0, 2); + // add the input - cur = ggml_add(ctxL, cur, inpL); + cur = ggml_add(ctx0, cur, inpL); struct ggml_tensor * inpFF = cur; @@ -1469,75 +1616,75 @@ static bool whisper_encode( { // norm { - cur = ggml_norm(ctxL, inpFF); + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + 
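// -----------------------------------------------------------------------------
// [editor's note, not part of the patch] The core of this sync: instead of one
// ggml context per layer, all layers are built in a single context and their
// intermediate tensors are placed in a few reusable scratch buffers selected
// with ggml_set_scratch() (wrapped by whisper_context::use_buf() above).
// Buffers are alternated so that an op never writes into the buffer its input
// still lives in. A reduced, self-contained sketch of the pattern - the buffer
// sizes and the toy math are assumptions for illustration only:

#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<uint8_t> buf_compute(8u*1024*1024); // graph metadata + final result
    std::vector<uint8_t> buf_scratch[2] = {
        std::vector<uint8_t>(8u*1024*1024),          // intermediates, buffer 0
        std::vector<uint8_t>(8u*1024*1024),          // intermediates, buffer 1
    };

    struct ggml_init_params params;
    params.mem_size   = buf_compute.size();
    params.mem_buffer = buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 64);
    for (int i = 0; i < 64; ++i) {
        ((float *) x->data)[i] = 1.0f;
    }

    struct ggml_tensor * cur = x;

    for (int il = 0; il < 2; ++il) {
        // the result of this op lives in scratch buffer 0
        ggml_set_scratch(ctx0, { 0, buf_scratch[0].size(), buf_scratch[0].data(), });
        cur = ggml_mul(ctx0, cur, x);

        // the next op reads from buffer 0 (and the main pool) and writes to buffer 1
        ggml_set_scratch(ctx0, { 0, buf_scratch[1].size(), buf_scratch[1].data(), });
        cur = ggml_add(ctx0, cur, x);
    }

    // the final result must not live in a scratch buffer that will be reused,
    // so switch the scratch off (use_buf(ctx0, -1) in the patch) before
    // creating it
    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
    cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, 0.5f));

    struct ggml_cgraph gf = {};
    gf.n_threads = 1;

    ggml_build_forward_expand(&gf, cur);
    ggml_graph_compute       (ctx0, &gf);

    fprintf(stdout, "result[0] = %f\n", ((float *) ggml_get_data(cur))[0]);

    ggml_free(ctx0);
    return 0;
}
// -----------------------------------------------------------------------------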
+ wctx.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } #ifdef WHISPER_USE_FLASH_FF - cur = ggml_flash_ff(ctxL, - ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)), + wctx.use_buf(ctx0, 0); + + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else + wctx.use_buf(ctx0, 0); + // fully connected - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); + wctx.use_buf(ctx0, 0); + // GELU activation - cur = ggml_gelu(ctxL, cur); + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 1); // projection - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), + wctx.use_buf(ctx0, 0); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); #endif } - // output from this layer - struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); - - { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); - - //ggml_graph_print(&gf); - } - - // TODO: this is a hack to have per-layer computation graphs - need to come up with something better - // input for next layer (inpO -> inpL) - memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; - inpL->src0 = nullptr; - inpL->src1 = nullptr; - - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); + wctx.use_buf(ctx0, 3); - ggml_free(ctxL); + inpL = ggml_add(ctx0, cur, inpFF); } cur = inpL; // norm { + wctx.use_buf(ctx0, 0); + cur = ggml_norm(ctx0, cur); + wctx.use_buf(ctx0, 1); + // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -1546,6 +1693,8 @@ static bool whisper_encode( ggml_repeat(ctx0, model.e_ln_b, cur)); } + wctx.use_buf(ctx0, -1); + // run the computation { struct ggml_cgraph gf = {}; @@ -1584,12 +1733,16 @@ static bool whisper_encode( for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; + wctx.use_buf(ctx0, 0); + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + wctx.use_buf(ctx0, 1); + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); @@ -1600,6 +1753,8 @@ static bool whisper_encode( Vcross), Vcross); + wctx.use_buf(ctx0, -1); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); @@ -1615,18 +1770,24 @@ static bool whisper_encode( //////////////////////////////////////////////////////////////////////////// - //printf("%s: used_mem = %f MB\n", __func__, 
ggml_used_mem(ctx0)/1024.0/1024.0); + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); ggml_free(ctx0); wctx.t_encode_us += ggml_time_us() - t_start_us; + wctx.n_encode++; return true; } // evaluate the decoder // -// given text prompt + audio features -> predicts the probabilities for the next token +// given text prompt + audio features -> computes the logits for the next token // // - model: the model // - n_threads: number of threads to use @@ -1670,6 +1831,9 @@ static bool whisper_decode( struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1678,6 +1842,8 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } + wctx.use_buf(ctx0, 3); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -1689,211 +1855,248 @@ static bool whisper_decode( for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_decoder[il]; - struct ggml_init_params paramsL; - paramsL.mem_size = wctx.buf_compute_layer.size(); - paramsL.mem_buffer = wctx.buf_compute_layer.data(); - - struct ggml_context * ctxL = ggml_init(paramsL); - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - // norm { - cur = ggml_norm(ctxL, inpL); + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_q_b, Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_v_b, Vcur), Vcur); // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); - 
ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); } // ------ + wctx.use_buf(ctx0, 0); + struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); + wctx.use_buf(ctx0, 1); + // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + wctx.use_buf(ctx0, 0); //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, + // ggml_scale(ctx0, // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) // ); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); + wctx.use_buf(ctx0, 0); struct ggml_tensor * V_trans = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + wctx.use_buf(ctx0, 1); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - cur = ggml_cpy(ctxL, + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); } + // projection { - cur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } + wctx.use_buf(ctx0, 2); + // add the input - struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL); + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here + + wctx.use_buf(ctx0, 1); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), cur), - ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur)); + ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); } // cross-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 0); + + struct 
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_q_b, Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // Kcross is already scaled struct ggml_tensor * Kcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), n_state/n_head, n_head, M); + struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); + // ------ + wctx.use_buf(ctx0, 1); + struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, + ggml_permute(ctx0, + ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), 0, 2, 1, 3); - struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3); + struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); + + wctx.use_buf(ctx0, 0); // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, + // ggml_scale(ctx0, // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) // ); // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ); + wctx.use_buf(ctx0, 1); - struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + wctx.use_buf(ctx0, 0); - struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_state, N) - cur = ggml_cpy(ctxL, + cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); } // projection { - cur = ggml_mul_mat(ctxL, + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur), + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } + wctx.use_buf(ctx0, 2); + // add the input - cur = ggml_add(ctxL, cur, inpCA); + cur = ggml_add(ctx0, cur, inpCA); struct ggml_tensor * inpFF = cur; @@ -1901,68 +2104,67 @@ static bool whisper_decode( { // norm { - cur = ggml_norm(ctxL, inpFF); + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + wctx.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - 
ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } + wctx.use_buf(ctx0, 0); + // fully connected - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); + wctx.use_buf(ctx0, 0); + // GELU activation - cur = ggml_gelu(ctxL, cur); + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 1); // projection - cur = ggml_mul_mat(ctxL, + cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), - cur); - } + wctx.use_buf(ctx0, 0); - // output from this layer - struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); - - { - ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); - - //ggml_graph_print(&gf); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); } - // TODO: this is a hack to have per-layer computation graphs - need to come up with something better - // input for next layer (inpO -> inpL) - memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; - inpL->src0 = nullptr; - inpL->src1 = nullptr; + wctx.use_buf(ctx0, 3); - if (N > 1) { - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); - } - - ggml_free(ctxL); + inpL = ggml_add(ctx0, cur, inpFF); } cur = inpL; // norm { + wctx.use_buf(ctx0, 0); + cur = ggml_norm(ctx0, cur); + wctx.use_buf(ctx0, 1); + cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.d_ln_w, cur), @@ -1970,29 +2172,44 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } + wctx.use_buf(ctx0, 0); + + // compute logits only for the last token + // comment this line to compute logits for all N tokens + // might be useful in the future + cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + wctx.use_buf(ctx0, -1); + // run the computation { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - ggml_build_forward_expand(&gf, logits); ggml_graph_compute (ctx0, &gf); } - logits_out.resize(N*n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + // extract logits for all N tokens + //logits_out.resize(N*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + + // extract logits only for the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); if (N > 1) { - //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; - //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); - //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); } ggml_free(ctx0); wctx.t_decode_us += ggml_time_us() - t_start_us; + wctx.n_decode++; return true; } @@ -2644,12 +2861,17 @@ whisper_token whisper_token_transcribe(void) { void whisper_print_timings(struct 
whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); + const int32_t n_sample = std::max(1, ctx->n_sample); + const int32_t n_encode = std::max(1, ctx->n_encode); + const int32_t n_decode = std::max(1, ctx->n_decode); + fprintf(stderr, "\n"); + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f); - fprintf(stderr, "%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer); - fprintf(stderr, "%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } @@ -2683,7 +2905,7 @@ const char * whisper_print_system_info(void) { struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + /*.strategy =*/ strategy, /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_max_text_ctx =*/ 16384, @@ -2702,6 +2924,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.thold_pt =*/ 0.01f, /*.thold_ptsum =*/ 0.01f, /*.max_len =*/ 0, + /*.split_on_word =*/ false, /*.max_tokens =*/ 0, /*.speed_up =*/ false, @@ -2713,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/true, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -2768,9 +2992,35 @@ static void whisper_exp_compute_token_level_timestamps( float thold_pt, float thold_ptsum); +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { + return !std::isspace(ch); + }).base(), s.end()); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + return txt[0] == ' '; +} + // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { +static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { auto segment = ctx.result_all.back(); int res = 1; @@ -2785,11 +3035,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { } const 
auto txt = whisper_token_to_str(&ctx, token.id); - const int cur = strlen(txt); - if (acc + cur > max_len && i > 0) { + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { // split here + if (split_on_word) { + trim(text); + } + ctx.result_all.back().text = std::move(text); ctx.result_all.back().t1 = token.t0; ctx.result_all.back().tokens.resize(i); @@ -2817,11 +3070,22 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { } } + if (split_on_word) { + trim(text); + } ctx.result_all.back().text = std::move(text); return res; } +static const std::vector non_speech_tokens +{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs @@ -2878,6 +3142,37 @@ static void whisper_process_logits( logits[vocab.token_sot] = -INFINITY; logits[vocab.token_solm] = -INFINITY; + // suppress task tokens + logits[vocab.token_translate] = -INFINITY; + logits[vocab.token_transcribe] = -INFINITY; + + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) + { + for (const std::string &token : non_speech_tokens) + { + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 { @@ -2910,6 +3205,16 @@ static void whisper_process_logits( } } + // condition timestamp tokens to be increasing + // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 + if (decoder.has_ts) { + const int tid0 = decoder.seek_delta/2; + + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + logits[i] = -INFINITY; + } + } + // populate the logprobs array (log_softmax) { const float logit_max = *std::max_element(logits.begin(), logits.end()); @@ -3004,7 +3309,7 @@ static void whisper_process_logits( } static whisper_token_data whisper_sample_token( - const whisper_context & ctx, + whisper_context & ctx, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -3059,6 +3364,8 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } + ctx.n_sample++; + return result; } @@ -3091,10 +3398,10 @@ static std::vector whisper_sample_token_topk( std::vector result; result.reserve(k); - whisper_token tid; + whisper_token tid = vocab.token_beg; - float pt; - float ptsum; + float pt = 0.0; + float ptsum = 0.0; { double sum_ts = 0.0; @@ -3127,6 +3434,8 @@ static std::vector whisper_sample_token_topk( 
} } + ctx.n_sample++; + return result; } @@ -3211,7 +3520,7 @@ int whisper_full( fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - + ctx->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); @@ -3308,6 +3617,7 @@ int whisper_full( std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); + ctx->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -3432,7 +3742,7 @@ int whisper_full( prompt.clear(); // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty() && t_cur > 0.5f) { + if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); prompt = { whisper_token_prev(ctx) }; @@ -3443,11 +3753,11 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); // print the prompt - //WHISPER_PRINT_DEBUG("\n\n"); - //for (int i = 0; i < (int) prompt.size(); i++) { - // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); - //} - //WHISPER_PRINT_DEBUG("\n\n"); + WHISPER_PRINT_DEBUG("\n\n"); + for (int i = 0; i < (int) prompt.size(); i++) { + WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + } + WHISPER_PRINT_DEBUG("\n\n"); if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3544,7 +3854,7 @@ int whisper_full( return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; }); - int cur_c = 0; + unsigned int cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3555,7 +3865,7 @@ int whisper_full( auto & cur = beam_candidates[cur_c++]; - while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { ++cur_c; } @@ -3721,11 +4031,12 @@ int whisper_full( WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); - if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) { + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; + ctx->n_fail_h++; continue; } @@ -3747,6 +4058,7 @@ int whisper_full( if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { success = false; + ctx->n_fail_p++; } if (success) { @@ -3801,6 +4113,7 @@ int whisper_full( if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (!text.empty()) { const auto tt0 = params.speed_up ? 2*t0 : t0; const auto tt1 = params.speed_up ? 
2*t1 : t1; @@ -3828,7 +4141,7 @@ int whisper_full( *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { @@ -3872,7 +4185,7 @@ int whisper_full( *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { @@ -4025,6 +4338,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) { return ctx->result_all.size(); } +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->lang_id; +} + int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { return ctx->result_all[i_segment].t0; } @@ -4059,6 +4376,145 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int // ================================================================================================= +// +// Temporary interface needed for exposing ggml interface +// Will be removed in the future when ggml becomes a separate library +// + +WHISPER_API int whisper_bench_memcpy(int n_threads) { + ggml_time_init(); + + size_t n = 50; + size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations + + // 1 GB array + const size_t size = arr*1024llu*1024llu; + + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + src[0] = rand(); + } + + fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + + // needed to prevent the compile from optimizing the memcpy away + { + double sum = 0.0; + + for (size_t i = 0; i < size; i++) sum += dst[i]; + + fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); + } + + free(src); + free(dst); + + return 0; +} + +WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { + ggml_time_init(); + + const int n_max = 128; + + const std::vector sizes = { + 64, 128, 256, 512, 1024, 2048, 4096, + }; + + const size_t N_max = sizes.back(); + + // a: N*N*sizeof(float) + // b: N*N*sizeof(float) + // c: N*N*sizeof(float) + // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) + std::vector buf(4llu*N_max*N_max*sizeof(float) + 4*256); + + for (size_t i = 0; i < buf.size(); i++) buf[i] = i; + + for (int j = 0; j < (int) sizes.size(); j++) { + int n_fp16 = 0; + int n_fp32 = 0; + + // GFLOPS/s + double s_fp16 = 0.0; + double s_fp32 = 0.0; + + const size_t N = sizes[j]; + + for (int k = 0; k < 2; ++k) { + const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + double & s = k == 0 ? s_fp16 : s_fp32; + int & n = k == 0 ? 
n_fp16 : n_fp32; + + struct ggml_init_params gparams = { + /*.mem_size =*/ buf.size(), + /*.mem_buffer =*/ buf.data(), + }; + + struct ggml_context * ctx0 = ggml_init(gparams); + + struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); + struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); + + struct ggml_cgraph gf = ggml_build_forward(c); + + gf.n_threads = n_threads; + + double tsum = 0.0; + + // heat-up + ggml_graph_compute(ctx0, &gf); + + for (int i = 0; i < n_max; ++i) { + const int64_t t0 = ggml_time_us(); + + ggml_graph_compute(ctx0, &gf); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + n++; + + if (tsum > 1.0 && n >= 3) { + break; + } + } + + ggml_free(ctx0); + + s = ((2.0*N*N*N*n)/tsum)*1e-9; + } + + fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", + N, N, s_fp16, n_fp16, s_fp32, n_fp32); + } + + return 0; +} + +// ================================================================================================= + +// ================================================================================================= + // // Experimental stuff below // diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 84504b7..7eece79 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -113,6 +113,16 @@ extern "C" { int n_samples, int n_threads); + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel_phase_vocoder( + struct whisper_context* ctx, + const float* samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 @@ -245,7 +255,7 @@ extern "C" { int duration_ms; // audio duration to process in ms bool translate; - bool no_context; // do not use initial prompt for the decoder (if any) + bool no_context; // do not use past transcription (if any) as initial prompt for the decoder bool single_segment; // force single segment output (useful for streaming) bool print_special; // print special tokens (e.g. , , , etc.) 
        bool print_progress;    // print progress information
@@ -257,6 +267,7 @@ extern "C" {
         float thold_pt;         // timestamp token probability threshold (~0.01)
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
@@ -274,6 +285,7 @@ extern "C" {
 
         // common decoding parameters:
         bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
 
         float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
         float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
@@ -329,6 +341,9 @@ extern "C" {
     // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
 
+    // Language id associated with the current context
+    WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
+
     // Get the start and end time of the specified segment.
     WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
     WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
@@ -350,6 +365,13 @@ extern "C" {
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Temporary helpers needed for exposing ggml interface
+
+    WHISPER_API int whisper_bench_memcpy(int n_threads);
+    WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index f3c9e5a..18f317b 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -301,6 +301,13 @@ struct ggml_cgraph {
     int64_t perf_time_us;
 };
 
+// scratch buffer
+struct ggml_scratch {
+    size_t offs;
+    size_t size;
+    void * data;
+};
+
 struct ggml_init_params {
     // memory pool
     size_t mem_size;   // bytes
@@ -327,6 +334,8 @@ void ggml_free(struct ggml_context * ctx);
 
 size_t ggml_used_mem(const struct ggml_context * ctx);
 
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
diff --git a/src/ggml.c b/src/ggml.c
index c59ee64..d67612c 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -339,8 +339,12 @@ int64_t ggml_cycles_per_ms(void) {
 #if defined(__cpp_lib_hardware_interference_size)
 #define CACHE_LINE_SIZE hardware_destructive_interference_size
 #else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
 #define CACHE_LINE_SIZE 64
 #endif
+#endif
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
@@ -609,9 +613,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?
\ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i], r[i - 1]), 0, p - GGML_F16_EPR) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) #elif defined(__wasm_simd128__) @@ -1251,7 +1258,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { // struct ggml_object { - size_t offset; + size_t offs; size_t size; struct ggml_object * next; @@ -1277,6 +1284,9 @@ struct ggml_context { struct ggml_object * objects_begin; struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; }; struct ggml_context_container { @@ -1339,7 +1349,7 @@ inline static void ggml_critical_section_end(void) { void ggml_print_object(const struct ggml_object * obj) { GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", - obj->offset, obj->size, (const void *) obj->next); + obj->offs, obj->size, (const void *) obj->next); } void ggml_print_objects(const struct ggml_context * ctx) { @@ -1535,12 +1545,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } *ctx = (struct ggml_context) { - .mem_size = params.mem_size, - .mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), - .mem_buffer_owned = params.mem_buffer ? false : true, - .n_objects = 0, - .objects_begin = NULL, - .objects_end = NULL, + /*.mem_size =*/ params.mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, }; ggml_assert_aligned(ctx->mem_buffer); @@ -1563,7 +1575,7 @@ void ggml_free(struct ggml_context * ctx) { g_state.contexts[i].used = false; GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", - __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size); + __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); if (ctx->mem_buffer_owned) { free(ctx->mem_buffer); @@ -1582,7 +1594,15 @@ void ggml_free(struct ggml_context * ctx) { } size_t ggml_used_mem(const struct ggml_context * ctx) { - return ctx->objects_end->offset + ctx->objects_end->size; + return ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; } //////////////////////////////////////////////////////////////////////////////// @@ -1596,9 +1616,9 @@ struct ggml_tensor * ggml_new_tensor_impl( // always insert objects at the end of the context's memory pool struct ggml_object * obj_cur = ctx->objects_end; - const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset; - const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; - const size_t cur_end = cur_offset + cur_size; + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 
0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; size_t size_needed = 0; @@ -1609,25 +1629,52 @@ struct ggml_tensor * ggml_new_tensor_impl( } // align to GGML_MEM_ALIGN size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; - - } - size_needed += sizeof(struct ggml_tensor); - - if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__); - assert(false); - return NULL; } char * const mem_buffer = ctx->mem_buffer; - struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); - *obj_new = (struct ggml_object) { - .offset = cur_end + GGML_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - }; + if (ctx->scratch.data == NULL || data != NULL) { + size_needed += sizeof(struct ggml_tensor); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + }; + } else { + if (ctx->scratch.offs + size_needed > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); + assert(false); + return NULL; + } + + if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = sizeof(struct ggml_tensor), + .next = NULL, + }; + + //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); + + ctx->scratch.offs += size_needed; + } if (obj_cur != NULL) { obj_cur->next = obj_new; @@ -1638,9 +1685,9 @@ struct ggml_tensor * ggml_new_tensor_impl( ctx->objects_end = obj_new; - //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end); + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); - struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset); + struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs); ggml_assert_aligned(result); @@ -1683,7 +1730,7 @@ struct ggml_tensor * ggml_new_tensor( struct ggml_context * ctx, enum ggml_type type, int n_dims, - const int* ne) { + const int * ne) { return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); } @@ -1725,16 +1772,26 @@ struct ggml_tensor * ggml_new_tensor_4d( } struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + ctx->scratch = ctx->scratch_save; + ggml_set_i32(result, value); return result; } struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ctx->scratch = ctx->scratch_save; + ggml_set_f32(result, value); return result; @@ -2343,7 +2400,7 @@ struct ggml_tensor * ggml_repeat( result->op = GGML_OP_REPEAT; result->grad = 
     is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL;
+    result->src1 = b;
 
     return result;
 }
@@ -2959,9 +3016,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     // TODO: when implement backward, fix this:
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-    ((int32_t *) b->data)[0] = n_past;
+    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4293,7 +4348,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)
+        )) {
         //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
         return true;
     }
@@ -4373,7 +4430,9 @@ static void ggml_compute_forward_mul_mat_f32(
         if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
             GGML_ASSERT(nb10 == sizeof(float));
 
-            if (params->ith != 0) return;
+            if (params->ith != 0) {
+                return;
+            }
 
             if (params->type == GGML_TASK_INIT) {
                 return;
@@ -4616,7 +4675,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
             GGML_ASSERT(nb10 == sizeof(float));
 
-            if (params->ith != 0) return;
+            if (params->ith != 0) {
+                return;
+            }
 
             if (params->type == GGML_TASK_INIT) {
                 return;
@@ -7054,7 +7115,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 #ifdef __APPLE__
 
 //#include <os/lock.h>
-
+//
 //typedef os_unfair_lock ggml_lock_t;
 //
 //#define ggml_lock_init(x) UNUSED(x)
@@ -7161,6 +7222,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             if (state->params.ith < state->params.nth) {
                 ggml_compute_forward(&state->params, state->node);
             }
+
             state->node = NULL;
         } else {
             break;
@@ -7205,6 +7267,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             .node   = NULL,
             .shared = &state_shared,
         };
+
         int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
         assert(rc == 0);
         UNUSED(rc);
@@ -7273,8 +7336,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                        node->n_tasks = 1;
+                        node->n_tasks = 1; // TODO: this actually is doing nothing
+                        // the threads are still spinning
                         cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
+                        //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                        //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                        //printf("cur = %zu\n", cur);
                     } else {
                         cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
                     }
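
Note (not part of the patch): the sketch below shows how the new scratch-buffer API added above (struct ggml_scratch, ggml_set_scratch) can be used from application code. The pool and scratch sizes here are arbitrary assumptions, and error checking is omitted; whisper.cpp itself drives the same calls through its use_buf() helper, as seen in the encoder/decoder hunks earlier in this patch.

// minimal, standalone usage sketch of ggml scratch buffers (illustrative only)
#include <stdio.h>
#include <stdlib.h>

#include "ggml.h"

int main(void) {
    // main pool: holds tensor metadata and any tensors created while no scratch buffer is set
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // arbitrary size for this example
        /*.mem_buffer =*/ NULL,
    };

    struct ggml_context * ctx = ggml_init(params);

    // separate buffer that will hold intermediate tensor data
    const size_t scratch_size = 16*1024*1024; // arbitrary size for this example
    void * scratch_data = malloc(scratch_size);

    // created before the scratch is set -> data lives in the main pool
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    ggml_set_f32(a, 1.0f);

    // from here on, new tensor data is placed in the scratch buffer instead of the main pool
    struct ggml_scratch scratch = { /*.offs =*/ 0, /*.size =*/ scratch_size, /*.data =*/ scratch_data, };
    ggml_set_scratch(ctx, scratch);

    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // data in scratch
    ggml_set_f32(b, 2.0f);

    struct ggml_tensor * c = ggml_add(ctx, a, b); // intermediate result, also in scratch

    // disable the scratch buffer again; the return value is the offset used in the previous scratch
    struct ggml_scratch none = { 0, 0, NULL, };
    const size_t used = ggml_set_scratch(ctx, none);
    fprintf(stderr, "scratch bytes used: %zu\n", used);

    struct ggml_cgraph gf = ggml_build_forward(c);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    fprintf(stderr, "c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expect 3.0

    ggml_free(ctx);
    free(scratch_data);

    return 0;
}

The point of the design is that only tensor metadata stays in the context's memory pool, while bulky intermediate data goes to a reusable scratch region; whisper.cpp cycles between several such regions (use_buf(ctx0, 0/1/2/3) and use_buf(ctx0, -1) to disable), which is what lets the reported memory usage per decoder/encoder graph drop in this change.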