Compare commits

...

23 Commits
t5 ... master

Author SHA1 Message Date
katsu560 4c2f924553
cmake : update CMakeLists.txt to add correct flags (#26)
1 year ago
Georgi Gerganov ba3e8a3d7f
readme : update Roadmap
1 year ago
Georgi Gerganov 2546cb7780
readme : add Roadmap section
1 year ago
Georgi Gerganov 8f8a5aca99
sync : latest whisper.cpp
1 year ago
Georgi Gerganov efa2cc36a2
tests : fix cblas_sgemm call
1 year ago
Georgi Gerganov 3b3ad42906
tests : add SVD experiments
1 year ago
Georgi Gerganov a6acb3318a
sync : latest whisper.cpp (scratch buffers in ggml)
1 year ago
Georgi Gerganov 47b297224e
Update README.md
1 year ago
Takuya Takeuchi 0467385010
cmake : configure CMAKE_C_FLAGS and target_link_libraries for MSVC (#15)
1 year ago
Georgi Gerganov fb64edddb7
gpt : fix sampling to use the temperature (close #16)
1 year ago
Georgi Gerganov c40a5b51a0
ggml : sync latest whisper.cpp
1 year ago
Georgi Gerganov a0f2f68cdb
gpt-2 : fix broken prompt due to recent experiments
1 year ago
Georgi Gerganov dee3684fec
ggml : sync latest whisper.cpp
1 year ago
Georgi Gerganov 6ed4da0b03
cmake : disable warnings about unused functions
1 year ago
Georgi Gerganov 06e2a3b721
ggml : bugfix in new soft max computation
1 year ago
Georgi Gerganov 78af1420bf
tests : change test2 eps
1 year ago
Georgi Gerganov 1af4cf0102
ggml : sync with latest whisper.cpp
1 year ago
Georgi Gerganov 73a7916d30
tests : some more quantization experiments
1 year ago
Georgi Gerganov e0abac1be7
sync : forgot to sync ggml.h
1 year ago
Georgi Gerganov 45fc4fed0b
sync : latest changes from whisper.cpp
1 year ago
Georgi Gerganov deb0c486c7
tests : wip quantized matrix multiplication method 2
1 year ago
Georgi Gerganov d677c7f61d tests : minor fixes for x86
1 year ago
Georgi Gerganov 446ccf3ab1
tests : experiments with n-bit quantized matrix multiplication
1 year ago

1
.gitignore vendored

@ -9,3 +9,4 @@ compile_commands.json
.DS_Store
src/arm_neon.h
tests/arm_neon.h

@ -47,6 +47,7 @@ endif()
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
# dependencies

@ -2,6 +2,9 @@
Tensor library for machine learning
***Note that this project is under development and not ready for production use. \
Some of the development is currently happening in the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repo***
## Features
- Written in C
@ -13,9 +16,15 @@ Tensor library for machine learning
- No third-party dependencies
- Zero memory allocations during runtime
***Note that this project is under development and not ready for production use.
Most of the development is currently happening in the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repo,
so if you are interested in this project, make sure to follow what is happening there.***
## Roadmap
- [X] Example of GPT-2 inference [examples/gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
- [X] Example of GPT-J inference [examples/gpt-j](https://github.com/ggerganov/ggml/tree/master/examples/gpt-j)
- [X] Example of Whisper inference [examples/whisper](https://github.com/ggerganov/ggml/tree/master/examples/whisper)
- [ ] Support 4-bit integer quantization https://github.com/ggerganov/ggml/pull/27
- [ ] Example of FLAN-T5 inference https://github.com/ggerganov/ggml/pull/12
- [ ] Example of LLaMA inference
- [ ] Example of RWKV inference
## Whisper inference (example)

@ -347,7 +347,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
// - embd_w: the predicted logits for the next token
//
bool gpt2_eval(
const gpt2_model & model,
@ -496,7 +496,6 @@ bool gpt2_eval(
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
#if 0
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
@ -504,15 +503,6 @@ bool gpt2_eval(
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
#else
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
#endif
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
@ -637,7 +627,7 @@ bool gpt2_eval(
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
@ -651,7 +641,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@ -708,7 +698,7 @@ int main(int argc, char ** argv) {
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
std::vector<float> embd_w;
std::vector<float> logits;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@ -724,14 +714,14 @@ int main(int argc, char ** argv) {
// determine the required inference memory per token:
size_t mem_per_token = 0;
gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
printf("Failed to predict\n");
return 1;
}
@ -755,7 +745,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();
id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}

@ -214,8 +214,11 @@ make -j4 gpt-j
```
To run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in
[ggml](https://github.com/ggerganov/ggml) compatible format. In the instructions above, I download the binary file
directly from one of my servers, using the [download-ggml-model.sh](download-ggml-model.sh) script.
[ggml](https://github.com/ggerganov/ggml) compatible format. In the instructions above, the binary file
is downloaded from my repository on Hugging Face using the [download-ggml-model.sh](download-ggml-model.sh) script.
You can also, download the file manually from this link:
https://huggingface.co/datasets/ggerganov/ggml/tree/main
---

@ -355,7 +355,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
// - embd_w: the predicted logits for the next token
//
// The GPT-J model requires about 16MB of memory per input token.
//
@ -559,7 +559,7 @@ bool gptj_eval(
}
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
std::vector<float> embd_w;
std::vector<float> logits;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@ -644,14 +644,14 @@ int main(int argc, char ** argv) {
// determine the required inference memory per token:
size_t mem_per_token = 0;
gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
if (!gptj_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
printf("Failed to predict\n");
return 1;
}
@ -675,7 +675,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();
id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}

@ -261,8 +261,11 @@ gpt_vocab::id gpt_sample_top_k_top_p(
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.push_back(std::make_pair(logits[i], i));
{
const double scale = 1.0/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
}
// find the top K tokens
@ -275,59 +278,51 @@ gpt_vocab::id gpt_sample_top_k_top_p(
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
//printf("\n");
//for (int i = 0; i < (int)logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);

@ -1,7 +1,7 @@
#
# whisper
add_library(whisper-cpp SHARED
add_library(whisper-cpp
whisper.cpp
)
@ -10,6 +10,6 @@ target_link_libraries(whisper-cpp PRIVATE
)
set(TEST_TARGET whisper)
add_executable(${TEST_TARGET} main.cpp)
add_executable(${TEST_TARGET} main.cpp common.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp)
target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)

@ -0,0 +1,162 @@
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <regex>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}

@ -0,0 +1,40 @@
#pragma once
// needs to match WHISPER_SAMPLE_RATE
#define COMMON_SAMPLE_RATE 16000
#include <vector>
#include <string>
std::string trim(const std::string & s);
std::string replace(
const std::string & s,
const std::string & from,
const std::string & to);
// Read WAV audio file and store the PCM data into pcmf32
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(
std::vector<float> & data,
float cutoff,
float sample_rate);
// Basic voice activity detection (VAD) using audio energy adaptive threshold
bool vad_simple(
std::vector<float> & pcmf32,
int sample_rate,
int last_ms,
float vad_thold,
float freq_thold,
bool verbose);

@ -1,9 +1,6 @@
#include "whisper.h"
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "whisper.h"
#include <cmath>
#include <fstream>
@ -53,18 +50,24 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
@ -80,6 +83,7 @@ struct whisper_params {
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -88,6 +92,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;
@ -104,15 +113,22 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -136,31 +152,38 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
@ -176,90 +199,81 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
std::string speaker;
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
double energy0 = 0.0f;
double energy1 = 0.0f;
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
double energy0 = 0.0f;
double energy1 = 0.0f;
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
@ -340,20 +354,16 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ')
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", "
<< 10 * t1 << ", \""
<< text << "\"\n";
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
@ -487,7 +497,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
@ -512,90 +522,14 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
}
// print system information
@ -629,6 +563,8 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
@ -643,11 +579,19 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s };
@ -682,34 +626,33 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_inp + ".txt";
const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_inp + ".vtt";
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_inp + ".srt";
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts";
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_inp + ".csv";
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

File diff suppressed because it is too large Load Diff

@ -1,6 +1,7 @@
#ifndef WHISPER_H
#define WHISPER_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
@ -40,7 +41,7 @@ extern "C" {
//
// ...
//
// struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
// struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
//
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
// fprintf(stderr, "failed to process audio\n");
@ -73,6 +74,7 @@ extern "C" {
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
@ -84,9 +86,20 @@ extern "C" {
float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx);
@ -100,6 +113,16 @@ extern "C" {
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context* ctx,
const float* samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
@ -124,6 +147,7 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
@ -131,14 +155,6 @@ extern "C" {
int n_past,
int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
@ -148,7 +164,7 @@ extern "C" {
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
int n_max_tokens);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
@ -180,8 +196,11 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -210,8 +229,8 @@ extern "C" {
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};
// Text segment callback
@ -224,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
@ -231,30 +260,32 @@ extern "C" {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
@ -262,21 +293,42 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int n_past;
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy;
struct {
int n_past;
int beam_width;
int n_best;
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -303,6 +355,9 @@ extern "C" {
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
// Language id associated with the current context
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Get the start and end time of the specified segment.
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
@ -324,6 +379,13 @@ extern "C" {
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
////////////////////////////////////////////////////////////////////////////
// Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
#ifdef __cplusplus
}
#endif

@ -301,6 +301,13 @@ struct ggml_cgraph {
int64_t perf_time_us;
};
// scratch buffer
struct ggml_scratch {
size_t offs;
size_t size;
void * data;
};
struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
@ -327,6 +334,8 @@ void ggml_free(struct ggml_context * ctx);
size_t ggml_used_mem(const struct ggml_context * ctx);
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx,
enum ggml_type type,
@ -731,6 +740,8 @@ int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
#ifdef __cplusplus
}

@ -9,6 +9,7 @@ if (GGML_ALL_WARNINGS)
-Wcast-qual \
-Wstrict-prototypes \
-Wpointer-arith \
-Wno-unused-function \
")
else()
# todo : windows
@ -17,17 +18,101 @@ endif()
# compiler flags
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
if (NOT MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
endif()
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (NOT UNAME_S)
execute_process(COMMAND uname -s OUTPUT_VARIABLE UNAME_S)
endif()
if (NOT UNAME_P)
execute_process(COMMAND uname -p OUTPUT_VARIABLE UNAME_P)
endif()
if (NOT UNAME_M)
execute_process(COMMAND uname -m OUTPUT_VARIABLE UNAME_M)
endif()
message(STATUS "UNAME_S: ${UNAME_S} UNAME_P: ${UNAME_P} UNAME_M: ${UNAME_M}")
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
if (UNAME_S MATCHES "Darwin")
if (NOT UNAME_P MATCHES "arm")
execute_process(COMMAND sysctl -n hw.optional.arm64 OUTPUT_VARIABLE SYSCTL_M)
if (SYSCTL_M MATCHES "1")
#set(UNAME_P "arm")
#set(UNAME_M "arm64")
message(WARNING "Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-#1282546789")
endif()
endif()
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=apple-m1")
else()
message(STATUS "x86 detected")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
if (UNAME_S MATCHES "Darwin")
execute_process(COMMAND sysctl machdep.cpu.features OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "AVX1.0")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND sysctl machdep.cpu.leaf7_features OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
if (AVX1_M MATCHES "FMA")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
elseif (UNAME_S MATCHES "Linux")
message(STATUS "Linux detected")
execute_process(COMMAND grep "avx " /proc/cpuinfo OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "avx")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND grep "avx2 " /proc/cpuinfo OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "avx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
execute_process(COMMAND grep "fma " /proc/cpuinfo OUTPUT_VARIABLE FMA_M)
if (FMA_M MATCHES "fma")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
execute_process(COMMAND grep "f16c " /proc/cpuinfo OUTPUT_VARIABLE F16C_M)
if (F16C_M MATCHES "f16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
execute_process(COMMAND grep "sse3 " /proc/cpuinfo OUTPUT_VARIABLE SSE3_M)
if (SSE3_M MATCHES "sse3")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse3")
endif()
message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}")
elseif (UNAME_S MATCHES "Haiku")
message(STATUS "Haiku detected")
execute_process(COMMAND sysinfo -cpu | grep "AVX " OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "avx")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND sysinfo -cpu | grep "AVX2 " OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "avx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
execute_process(COMMAND sysinfo -cpu | grep "FMA " OUTPUT_VARIABLE FMA_M)
if (FMA_M MATCHES "fma")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
execute_process(COMMAND sysinfo -cpu | grep "F16C " OUTPUT_VARIABLE F16C_M)
if (F16C_M MATCHES "f16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c -mavx -mavx2")
endif()
endif()
@ -62,7 +147,11 @@ target_include_directories(${TARGET} PUBLIC
../include/ggml
)
target_link_libraries(${TARGET} PUBLIC m ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
if (MSVC)
target_link_libraries(${TARGET} PUBLIC ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
else()
target_link_libraries(${TARGET} PUBLIC m ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
endif()
if (BUILD_SHARED_LIBS)
target_link_libraries(${TARGET} PUBLIC

File diff suppressed because it is too large Load Diff

@ -65,6 +65,14 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" AND NOT GGML_NO_ACCELERATE)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
endif()
#
# test-mul-mat2
set(TEST_TARGET test-mul-mat2)
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
#
# test0
@ -96,3 +104,15 @@ set(TEST_TARGET test3)
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
#
# test-svd0 (arm)
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" AND NOT GGML_NO_ACCELERATE)
set(TEST_TARGET test-svd0)
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml ${GGML_EXTRA_LIBS})
target_compile_options(${TEST_TARGET} PRIVATE ${GGML_EXTRA_FLAGS})
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
endif()

@ -13,14 +13,20 @@
#include <Accelerate/Accelerate.h>
const int M = 1280;
const int N = 1500;
const int N = 1536;
const int K = 1280;
uint64_t get_time_us() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000000 + tv.tv_usec;
}
//
// naive implementation
//
void mul_mat_vec_f32_0(
void mul_mat_f32_0(
const float * restrict src0, // M x K
const float * restrict src1, // N x K (transposed)
float * dst,
@ -36,7 +42,7 @@ void mul_mat_vec_f32_0(
}
}
void mul_mat_vec_f16_0(
void mul_mat_f16_0(
const __fp16 * src0,
const __fp16 * src1,
float * dst,
@ -102,7 +108,7 @@ void mul_mat_vec_f16_0(
}
// blocking with block size 32
void mul_mat_vec_f16_1(
void mul_mat_f16_1(
const __fp16 * src0,
const __fp16 * src1,
float * dst,
@ -174,7 +180,7 @@ void mul_mat_vec_f16_1(
}
void mul_mat_vec_f8_0(
void mul_mat_f8_0(
const uint8_t * src0,
const uint8_t * src1,
float * dst,
@ -206,12 +212,6 @@ void mul_mat_vec_f8_0(
}
}
uint64_t get_time_us() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000000 + tv.tv_usec;
}
int main(int argc, const char ** argv) {
float * src0 = (float *)malloc(sizeof(float)*M*K);
float * src1 = (float *)malloc(sizeof(float)*N*K);
@ -258,7 +258,7 @@ int main(int argc, const char ** argv) {
method = atoi(argv[1]);
}
const int nIter = 10000;
const int nIter = 1;
const clock_t start = clock();
const uint64_t start_us = get_time_us();
@ -267,24 +267,24 @@ int main(int argc, const char ** argv) {
double sum = 0.0f;
for (int i = 0; i < nIter; i++) {
if (method == 0) {
mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
mul_mat_f32_0(src0, src1, dst, M, N, K);
}
if (method == 1) {
mul_mat_vec_f16_0(src0_fp16, src1_fp16, dst, M, N, K);
mul_mat_f16_0(src0_fp16, src1_fp16, dst, M, N, K);
}
if (method == 2) {
mul_mat_vec_f16_1(src0_fp16, src1_fp16, dst, M, N, K);
mul_mat_f16_1(src0_fp16, src1_fp16, dst, M, N, K);
}
if (method == 3) {
mul_mat_vec_f8_0(src0_fp8, src1_fp8, dst, M, N, K);
mul_mat_f8_0(src0_fp8, src1_fp8, dst, M, N, K);
}
if (method == 4) {
// Use BLAS sgemm from Accelerate framework
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, K, N, 1.0f, src0, N, src1, N, 0.0f, dst, N);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K, 1.0f, src0, K, src1, K, 0.0f, dst, N);
}
}

@ -0,0 +1,475 @@
// quantized matrix multiplication
#include "ggml.h"
#include <float.h>
#include <stdint.h>
#include <stdio.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <sys/time.h>
#ifdef __ARM_NEON
#include "arm_neon.h"
#endif
#ifndef MIN
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
const int M = 1280;
const int N = 1536;
const int K = 1280;
const int QK = 64;
#define QB 7
//#define GGML_GQ_USE_FP16_SCALE
#if defined(GGML_GQ_USE_FP16_SCALE)
#define gq_scale_t ggml_fp16_t
#define GGML_FP32_TO_GQ(x) ggml_fp32_to_fp16(x)
#define GGML_GQ_TO_FP32(x) ggml_fp16_to_fp32(x)
#else
#define gq_scale_t float
#define GGML_FP32_TO_GQ(x) (x)
#define GGML_GQ_TO_FP32(x) (x)
#endif
#define gq_quant_t uint64_t
#define gq_t_bits 64
uint64_t get_time_us() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000000 + tv.tv_usec;
}
//
// naive implementation
//
void mul_mat_f32_naive(
const float * restrict src0, // M x K
const float * restrict src1, // N x K (transposed)
float * dst,
int m, int n, int k) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
float sum = 0;
for (int l = 0; l < k; l++) {
sum += src0[i*k + l] * src1[j*k + l];
}
dst[i*n + j] = sum;
}
}
}
//
// method 1
//
void quantize_1(const float * src, void * dst, int n, int k) {
char * p0 = dst;
gq_quant_t pp[QB];
for (int j = 0; j < n; j++) {
for (int i = 0; i < k/QK; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min/max
#ifdef __ARM_NEON
{
float32x4_t minv = vdupq_n_f32(FLT_MAX);
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
for (int l = 0; l < QK; l += 4) {
float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
minv = vminq_f32(minv, v);
maxv = vmaxq_f32(maxv, v);
}
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
//printf("SIMD min/max: %f %f\n", min, max);
}
#else
{
for (int l = 0; l < QK; l++) {
const float v = src[j*k + i*QK + l];
if (v < min) min = v;
if (v > max) max = v;
}
//printf("NORM min/max: %f %f\n", min, max);
}
#endif
const float d = (max - min) / ((1 << QB) - 1);
const float id = d ? 1.0/d : 0.0;
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
for (int s = 0; s < QK/gq_t_bits; ++s) {
memset(pp, 0, sizeof(pp));
for (int l = 0; l < gq_t_bits; l++) {
const float v = src[j*k + i*QK + s*gq_t_bits + l];
const uint8_t q = (v - min)*id;
for (int b = 0; b < QB; b++) {
pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
}
}
for (int b = 0; b < QB; b++) {
memcpy(p0, &pp[b], sizeof(gq_quant_t)); p0 += sizeof(gq_quant_t);
}
}
}
}
}
void mul_mat_gq_1(
const void * src0,
const void * src1,
float * dst,
int m, int n, int k) {
const int kp = k & ~(gq_t_bits - 1);
const char * restrict p0 = src0;
const char * restrict p1 = src1;
float s0[QB + 1];
float s1[QB + 1];
gq_quant_t m0[QB + 1];
gq_quant_t m1[QB + 1];
for (int ir0 = 0; ir0 < m; ir0++) {
for (int ir1 = 0; ir1 < n; ir1++) {
float sumf = 0.0;
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_quant_t))*(k/QK));
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_quant_t))*(k/QK));
for (int i = 0; i < kp/QK; i++) {
float min0, d0;
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
float min1, d1;
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
#if 1
// >>> General case for any QB
s0[0] = min0;
s1[0] = min1;
for (int b = 0; b < QB; b++) {
s0[b + 1] = d0*(1 << b);
s1[b + 1] = d1*(1 << b);
}
m0[0] = -1ULL;
m1[0] = -1ULL;
for (int s = 0; s < QK/gq_t_bits; ++s) {
for (int b = 0; b < QB; b++) {
memcpy(&m0[b + 1], pp0, sizeof(gq_quant_t)); pp0 += sizeof(gq_quant_t);
memcpy(&m1[b + 1], pp1, sizeof(gq_quant_t)); pp1 += sizeof(gq_quant_t);
}
for (int q0 = 0; q0 < QB + 1; q0++) {
for (int q1 = 0; q1 < QB + 1; q1++) {
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
}
}
}
#else
#endif
}
dst[ir0*n + ir1] = sumf;
}
}
}
//
// method 2
//
static inline int quantize_2_blocks_per_row(int k) {
return k/QK;
}
static inline int quantize_2_quants_per_block() {
return QK/gq_t_bits;
}
static inline int quantize_2_row_size(int k) {
const int nb = quantize_2_blocks_per_row(k);
const int nq = quantize_2_quants_per_block();
return nb*(2*sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
}
void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
assert(k % QK == 0);
const int nb = quantize_2_blocks_per_row(k);
const int nq = quantize_2_quants_per_block();
gq_scale_t * restrict pm = (gq_scale_t *) (dst);
gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb);
gq_quant_t * restrict pb = (gq_quant_t *) (pd + nb);
gq_quant_t pp[QB];
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
for (int l = 0; l < QK; l++) {
const float v = src[i*QK + l];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << QB) - 1);
const float id = d ? 1.0/d : 0.0;
pm[i] = GGML_FP32_TO_GQ(min);
pd[i] = GGML_FP32_TO_GQ(d);
for (int s = 0; s < nq; ++s) {
memset(pp, 0, sizeof(pp));
for (int l = 0; l < gq_t_bits; l++) {
const float v = src[i*QK + s*gq_t_bits + l];
const uint8_t q = (v - min)*id;
for (int b = 0; b < QB; b++) {
pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
}
}
for (int b = 0; b < QB; b++) {
pb[i*nq*QB + s*QB + b] = pp[b];
}
}
}
}
// reimplementation of quantize_2 using quantize_2_row
void quantize_2(const float * restrict src, char * restrict dst, int n, int k) {
assert(k % QK == 0);
for (int j = 0; j < n; j++) {
quantize_2_row(src + j*k, dst, k);
dst = (char *) dst + quantize_2_row_size(k);
}
}
void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
float sumf[(QB + 1)*(QB + 1)];
memset(sumf, 0, sizeof(sumf));
const int nb = quantize_2_blocks_per_row(n);
const int nq = quantize_2_quants_per_block();
const gq_scale_t * restrict pm0 = (const gq_scale_t *) x;
const gq_scale_t * restrict pm1 = (const gq_scale_t *) y;
const gq_scale_t * restrict pd0 = pm0 + nb;
const gq_scale_t * restrict pd1 = pm1 + nb;
const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
#if 1
float s0[QB + 1];
float s1[QB + 1];
for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
s0[0] = m0;
s1[0] = m1;
for (int b = 0; b < QB; b++) {
s0[b + 1] = d0*(1 << b);
s1[b + 1] = d1*(1 << b);
}
for (int s = 0; s < nq; ++s) {
for (int q0 = 0; q0 < QB + 1; q0++) {
const gq_quant_t mm0 = q0 ? pb0[i*nq*QB + s*QB + q0 - 1] : -1ULL;
for (int q1 = 0; q1 < QB + 1; q1++) {
const gq_quant_t mm1 = q1 ? pb1[i*nq*QB + s*QB + q1 - 1] : -1ULL;
sumf[q0*(QB + 1) + q1] += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
}
}
}
}
#else
// SIMD-ify with the assumptions:
// - nb is a multiple of 4
// - gq_scale_t is float
// - gq_quant_t is uint64_t
// - QB == 7
assert(nb % 4 == 0);
#ifdef __ARM_NEON
#else
// TODO
#endif
#endif
for (int q0 = 0; q0 < QB + 1; q0++) {
for (int q1 = 1; q1 < QB + 1; q1++) {
sumf[q0*(QB + 1)] += sumf[q0*(QB + 1) + q1];
}
}
*s = sumf[0];
for (int q0 = 1; q0 < QB + 1; q0++) {
*s += sumf[q0*(QB + 1)];
}
}
// use vec_dot_gq_2 to compute the dot product of two rows
void mul_mat_gq_2(
const void * src0,
const void * src1, // transposed
float * dst,
int m, int n, int k) {
assert(k % QK == 0);
const int nb = quantize_2_blocks_per_row(k);
const int nq = quantize_2_quants_per_block();
for (int ir0 = 0; ir0 < m; ir0++) {
for (int ir1 = 0; ir1 < n; ir1++) {
vec_dot_gq_2(k, dst + ir1, src0, src1);
src1 = (const char *) src1 + quantize_2_row_size(k);
}
src0 = (const char *) src0 + quantize_2_row_size(k);
src1 = (const char *) src1 - n*quantize_2_row_size(k);
dst = (float *) dst + n;
}
}
int main(int argc, const char ** argv) {
assert(sizeof(gq_quant_t)*8 == gq_t_bits);
float * src0 = (float *)malloc(sizeof(float)*M*K);
float * src1 = (float *)malloc(sizeof(float)*N*K);
float * dst = (float *)malloc(sizeof(float)*M*N);
for (int i = 0; i < M*K; i++) {
src0[i] = rand() / (float)RAND_MAX;
}
for (int i = 0; i < N*K; i++) {
src1[i] = rand() / (float)RAND_MAX;
}
void * src0_gq = calloc(1, quantize_2_row_size(K)*M);
void * src1_gq = calloc(1, quantize_2_row_size(K)*N);
const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
const size_t sizegq = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N;
printf("compression: %f\n", (float)sizegq/sizef16);
int method = 0;
if (argc > 1) {
method = atoi(argv[1]);
}
// convert fp32 -> gq
{
const uint64_t t_start = get_time_us();
if (method == 1) {
quantize_1(src0, src0_gq, M, K);
quantize_1(src1, src1_gq, N, K);
}
if (method == 2) {
quantize_2(src0, src0_gq, M, K);
quantize_2(src1, src1_gq, N, K);
}
const uint64_t t_end = get_time_us();
printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
}
const int nIter = 1;
const clock_t start = clock();
const uint64_t start_us = get_time_us();
double iM = 1.0/M;
double sum = 0.0f;
for (int i = 0; i < nIter; i++) {
if (method == 0) {
mul_mat_f32_naive(src0, src1, dst, M, N, K);
}
if (method == 1) {
mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K);
}
if (method == 2) {
mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K);
}
}
for (int i = 0; i < N; i++) {
sum += dst[i]*iM;
}
{
const clock_t end = clock();
const uint64_t end_us = get_time_us();
printf("%s: elapsed ticks: %ld\n", __func__, end - start);
printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter);
}
printf("%f\n", sum);
free(src0);
free(src1);
free(dst);
free(src0_gq);
free(src1_gq);
return 0;
}

@ -0,0 +1,218 @@
// SVD dimensionality reduction
#include <float.h>
#include <stdint.h>
#include <stdio.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <sys/time.h>
#ifdef GGML_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#endif
float frand() {
return (float) rand() / (float) RAND_MAX;
}
//int sgesvd_(char *__jobu, char *__jobvt, __CLPK_integer *__m,
// __CLPK_integer *__n, __CLPK_real *__a, __CLPK_integer *__lda,
// __CLPK_real *__s, __CLPK_real *__u, __CLPK_integer *__ldu,
// __CLPK_real *__vt, __CLPK_integer *__ldvt, __CLPK_real *__work,
// __CLPK_integer *__lwork,
// __CLPK_integer *__info)
int main(int argc, const char ** argv) {
int m = 10;
int n = 5;
float * A = (float *) malloc(n * m * sizeof(float));
float * A0 = (float *) malloc(n * m * sizeof(float));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
A[i * m + j] = (float) (10.0f*(i + 1) + 1.0f * frand());
//A[i * m + j] = (float) (10.0f*(i%2 + 1) + 0.1f * frand());
//if (i == 2) {
// A[i * m + j] += 20*frand();
//}
if ((i == 1 || i == 3) && j > m/2) {
A[i * m + j] = -A[i * m + j];
}
}
}
// average vector
//float * M = (float *) malloc(m * sizeof(float));
//{
// for (int j = 0; j < m; ++j) {
// M[j] = 0.0f;
// }
// for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// M[j] += A[i * m + j];
// }
// }
// for (int j = 0; j < m; ++j) {
// M[j] /= (float) n;
// }
//}
//// subtract average vector
//for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// A[i * m + j] -= M[j];
// }
//}
memcpy(A0, A, n * m * sizeof(float));
// print A
printf("A:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < m; ++j) {
printf("%9.5f ", A[i * m + j]);
}
printf("\n");
}
printf("\n");
// SVD
// A = U * S * V^T
float * U = (float *) malloc(n * m * sizeof(float));
float * S = (float *) malloc(n * sizeof(float));
float * V = (float *) malloc(n * n * sizeof(float));
int lda = m;
int ldu = m;
int ldvt = n;
float work_size;
int lwork = -1;
int info = 0;
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, &work_size, &lwork, &info);
lwork = (int) work_size;
printf("work_size = %f, info = %d, lwork = %d\n", work_size, info, lwork);
float * work = (float *) malloc(lwork * sizeof(float));
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, work, &lwork, &info);
// print U
printf("U:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < m; ++j) {
printf("%9.5f ", U[i * m + j]);
}
printf("\n");
}
printf("\n");
// normalize S
{
double sum = 0.0;
for (int i = 0; i < n; ++i) {
sum += S[i];
}
sum *= sqrt((double) m);
for (int i = 0; i < n; ++i) {
S[i] /= sum;
}
}
// print S
printf("S:\n");
for (int i = 0; i < n; ++i) {
printf("- %d = %9.5f\n", i, S[i]);
}
printf("\n");
// print V
printf("V:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < n; ++j) {
printf("%9.5f ", V[i * n + j]);
}
printf("\n");
}
printf("\n");
// print A
printf("A:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < m; ++j) {
printf("%9.5f ", A[i * m + j]);
}
printf("\n");
}
printf("\n");
// compute singular vectors in U
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
U[i * m + j] *= S[i];
}
}
// normalize U
for (int i = 0; i < n; ++i) {
double sum = 0.0;
for (int j = 0; j < m; ++j) {
sum += U[i * m + j] * U[i * m + j];
}
sum = sqrt(sum);
for (int j = 0; j < m; ++j) {
U[i * m + j] /= sum*sqrt((double) m);
}
}
// print U
printf("U:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < m; ++j) {
printf("%9.5f ", U[i * m + j]);
}
printf("\n");
}
printf("\n");
// project A0 onto U
float * A1 = (float *) malloc(n * n * sizeof(float));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
A1[i * n + j] = 0.0f;
for (int k = 0; k < m; ++k) {
A1[i * n + j] += A0[i * m + k] * U[j * m + k];
}
}
}
// print A1
printf("A1:\n");
for (int i = 0; i < n; ++i) {
printf("col %d : ", i);
for (int j = 0; j < n; ++j) {
printf("%9.5f ", A1[i * n + j]);
}
printf("\n");
}
printf("\n");
return 0;
}

@ -96,8 +96,8 @@ int main(int argc, const char ** argv) {
enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
assert(res == GGML_OPT_OK);
assert(is_close(ggml_get_f32_1d(t0, 0), 5.0f, 1e-3f));
assert(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-3f));
assert(is_close(ggml_get_f32_1d(t0, 0), 5.0f, 1e-2f));
assert(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-2f));
}
{

Loading…
Cancel
Save