sync : latest changes from whisper.cpp

- Documentation
- whisper : token-level timestamps
- ggml : Windows build fixes
- etc.
pull/12/head
Georgi Gerganov 2 years ago
parent db13973820
commit 90ee5c6358
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf); return std::string(buf);
} }
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) { void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) { for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos); pos = s.find(search, pos);
@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
} }
} }
// a cost-function that is high for text that takes longer to pronounce
float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
// command-line parameters // command-line parameters
struct whisper_params { struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently int32_t seed = -1; // RNG seed, not used currently
@ -77,7 +53,9 @@ struct whisper_params {
int32_t n_processors = 1; int32_t n_processors = 1;
int32_t offset_t_ms = 0; int32_t offset_t_ms = 0;
int32_t offset_n = 0; int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0;
float word_thold = 0.01f; float word_thold = 0.01f;
@ -118,8 +96,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.offset_t_ms = std::stoi(argv[++i]); params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") { } else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]); params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-d" || arg == "--duration") {
params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") { } else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]); params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") { } else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]); params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") { } else if (arg == "-v" || arg == "--verbose") {
@ -175,14 +157,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors); fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n"); fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
@ -192,17 +176,18 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data; const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
// print the last segment // print the last n_new segments
const int i = n_segments - 1; const int s0 = n_segments - n_new;
if (i == 0) { if (s0 == 0) {
printf("\n"); printf("\n");
} }
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) { if (params.no_timestamps) {
if (params.print_colors) { if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
@ -254,6 +239,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, void * user_da
} }
} }
} }
}
bool output_txt(struct whisper_context * ctx, const char * fname) { bool output_txt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
@ -320,288 +306,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true; return true;
} }
// word-level timestamps (experimental) // karaoke video generation
// TODO: probably still has bugs, needs refactoring, etc.. // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: auto threshold
// TODO: extra pass to detect unused speech and assign to tokens
// TODO: font parameter adjustments // TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) { bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
if (params.output_wts) {
std::vector<float> pcm_avg(pcmf32.size(), 0);
// average the fabs of the signal
{
const int hw = 32;
for (int i = 0; i < pcmf32.size(); i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < pcmf32.size()) {
sum += fabs(pcmf32[i + j]);
}
}
pcm_avg[i] = sum/(2*hw + 1);
}
}
struct token_info {
int64_t t0 = -1;
int64_t t1 = -1;
int64_t tt0 = -1;
int64_t tt1 = -1;
whisper_token id;
whisper_token tid;
float p = 0.0f;
float pt = 0.0f;
float ptsum = 0.0f;
std::string text;
float vlen = 0.0f; // voice length of this token
};
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last = 0;
std::ofstream fout(fname); std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "!/bin/bash" << "\n"; // TODO: become parameter
fout << "\n"; static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \""; fout << "#!/bin/bash" << "\n";
fout << "\n";
bool is_first = true; fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
for (int i = 0; i < whisper_full_n_segments(ctx); i++) { for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char *text = whisper_full_get_segment_text(ctx, i);
const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
const int n = whisper_full_n_tokens(ctx, i); const int n = whisper_full_n_tokens(ctx, i);
std::vector<token_info> tokens(n); std::vector<whisper_token_data> tokens(n);
if (n <= 1) {
continue;
}
for (int j = 0; j < n; ++j) { for (int j = 0; j < n; ++j) {
struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j); tokens[j] = whisper_full_get_token_data(ctx, i, j);
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].text = whisper_token_to_str(ctx, token.id);
//tokens[j].vlen = tokens[j].pt;
tokens[j].vlen = voice_length(tokens[j].text);
if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
} }
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
for (int j = p0 + 1; j <= p1; j++) { if (i > 0) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
//const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
//const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
tokens[j].tt0 = tokens[j].t0;
tokens[j].tt1 = tokens[j].t1;
}
// VAD
{
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
const int64_t t0 = tokens[j].t0;
const int64_t t1 = tokens[j].t1;
int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
const int n = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += pcm_avg[k];
}
const float thold = 0.5*sum/n;
{
int k = s0;
if (pcm_avg[k] > thold && j > 0) {
while (k > 0 && pcm_avg[k] > thold) {
k--;
}
tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (pcm_avg[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
}
}
{
int k = s1;
if (pcm_avg[k] > thold) {
while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
k++;
}
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (pcm_avg[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
}
}
}
}
const int t_expand = 0;
for (int j = 0; j < n; j++) {
if (j > 0) {
tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
}
if (j < n - 1) {
tokens[j].t1 = tokens[j].t1 + t_expand;
}
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
//printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
//fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
}
static const int line_wrap = 60;
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
if (!is_first) {
fout << ","; fout << ",";
} }
// background text // background text
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
is_first = false; bool is_first = true;
for (int j = 0; j < n; ++j) { for (int j = 0; j < n; ++j) {
const auto & token = tokens[j]; const auto & token = tokens[j];
@ -645,17 +384,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
} }
ncnt += txt.size(); ncnt += txt.size();
if (ncnt > line_wrap) {
if (k < j) {
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
ncnt = 0;
} else {
break;
}
}
} }
::replace_all(txt_bg, "'", ""); ::replace_all(txt_bg, "'", "");
@ -664,8 +392,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
::replace_all(txt_fg, "\"", "\\\""); ::replace_all(txt_fg, "\"", "\\\"");
} }
if (is_first) {
// background text // background text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'"; fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
is_first = false;
}
// foreground text // foreground text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
@ -686,7 +417,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout.close(); fout.close();
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
}
return true; return true;
} }
@ -724,9 +454,30 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32; std::vector<float> pcmf32;
{ {
drwav wav; drwav wav;
if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); if (fname_inp == "-") {
whisper_print_usage(argc, argv, {}); std::vector<uint8_t> wav_data;
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0)
{
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false)
{
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 4; return 4;
} }
@ -806,6 +557,11 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms; wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment // this callback is called on each new segment
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
@ -844,7 +600,7 @@ int main(int argc, char ** argv) {
// output to WTS file // output to WTS file
if (params.output_wts) { if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts"; const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32); output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
} }
} }

@ -133,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024*1024; static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = { static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 86ull*MB }, { MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 165ull*MB }, { MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 540ull*MB }, { MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1650ull*MB }, { MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 3260ull*MB }, { MODEL_LARGE, 2952ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
}; };
static const std::map<e_model, size_t> MEM_REQ_ENCODE = { static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@ -410,6 +418,12 @@ struct whisper_context {
std::vector<whisper_segment> result_all; std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past; std::vector<whisper_token> prompt_past;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
}; };
// load the model from a ggml file // load the model from a ggml file
@ -423,7 +437,7 @@ struct whisper_context {
// //
// see the convert-pt-to-ggml.py script for details // see the convert-pt-to-ggml.py script for details
// //
bool whisper_model_load(const std::string & fname, whisper_context & wctx) { static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model; auto & model = wctx.model;
@ -498,7 +512,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
wctx.buf_model = new std::vector<uint8_t>(); wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!! wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
@ -722,20 +736,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
} }
} }
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights // prepare memory for the weights
{ {
auto & ctx = model.ctx; auto & ctx = model.ctx;
@ -932,6 +932,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
} }
} }
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory // key + value memory
{ {
auto & ctx = model.ctx_mem; auto & ctx = model.ctx_mem;
@ -1054,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use // - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset) // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
// //
bool whisper_encode( static bool whisper_encode(
whisper_context & wctx, whisper_context & wctx,
const int n_threads, const int n_threads,
const int mel_offset) { const int mel_offset) {
@ -1440,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt // - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with // - n_past: number of past tokens to prefix the prompt with
// //
bool whisper_decode( static bool whisper_decode(
whisper_context & wctx, whisper_context & wctx,
const int n_threads, const int n_threads,
const whisper_token * tokens, const whisper_token * tokens,
@ -1803,10 +1817,12 @@ bool whisper_decode(
} }
// the most basic sampling scheme - select the top token // the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best( static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab, const whisper_vocab & vocab,
const float * probs) { const float * probs) {
whisper_token_data result; whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab.id_to_token.size(); int n_logits = vocab.id_to_token.size();
@ -1879,7 +1895,7 @@ whisper_token_data whisper_sample_best(
} }
// samples only from the timestamps tokens // samples only from the timestamps tokens
whisper_vocab::id whisper_sample_timestamp( static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab, const whisper_vocab & vocab,
const float * probs) { const float * probs) {
int n_logits = vocab.id_to_token.size(); int n_logits = vocab.id_to_token.size();
@ -1931,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
// naive Discrete Fourier Transform // naive Discrete Fourier Transform
// input is real-valued // input is real-valued
// output is complex-valued // output is complex-valued
void dft(const std::vector<float> & in, std::vector<float> & out) { static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size(); int N = in.size();
out.resize(N*2); out.resize(N*2);
@ -1955,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better // poor man's implementation - use something better
// input is real-valued // input is real-valued
// output is complex-valued // output is complex-valued
void fft(const std::vector<float> & in, std::vector<float> & out) { static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2); out.resize(in.size()*2);
int N = in.size(); int N = in.size();
@ -2006,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
} }
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram( static bool log_mel_spectrogram(
const float * samples, const float * samples,
const int n_samples, const int n_samples,
const int sample_rate, const int sample_rate,
@ -2323,6 +2339,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384, /*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0, /*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false, /*.translate =*/ false,
/*.no_context =*/ false, /*.no_context =*/ false,
@ -2331,6 +2348,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false, /*.print_realtime =*/ false,
/*.print_timestamps =*/ true, /*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en", /*.language =*/ "en",
/*.greedy =*/ { /*.greedy =*/ {
@ -2355,6 +2377,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384, /*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0, /*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false, /*.translate =*/ false,
/*.no_context =*/ false, /*.no_context =*/ false,
@ -2363,6 +2386,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false, /*.print_realtime =*/ false,
/*.print_timestamps =*/ true, /*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en", /*.language =*/ "en",
/*.greedy =*/ { /*.greedy =*/ {
@ -2384,6 +2412,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result; return result;
} }
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
auto segment = ctx->result_all.back();
int res = 1;
int acc = 0;
std::string text;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
if (token.id >= whisper_token_eot(ctx)) {
continue;
}
const auto txt = whisper_token_to_str(ctx, token.id);
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) {
// split here
ctx->result_all.back().text = std::move(text);
ctx->result_all.back().t1 = token.t0;
ctx->result_all.back().tokens.resize(i);
ctx->result_all.push_back({});
ctx->result_all.back().t0 = token.t0;
ctx->result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
ctx->result_all.back().tokens.insert(
ctx->result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
acc = 0;
text = "";
segment = ctx->result_all.back();
i = -1;
res++;
} else {
acc += cur;
text += txt;
}
}
ctx->result_all.back().text = std::move(text);
return res;
}
int whisper_full( int whisper_full(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_full_params params, struct whisper_full_params params,
@ -2400,12 +2490,20 @@ int whisper_full(
return -1; return -1;
} }
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
const int seek_start = params.offset_ms/10; const int seek_start = params.offset_ms/10;
const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
// if length of spectrogram is less than 1s (100 samples), then return // if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s // basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (whisper_n_len(ctx) < 100 + seek_start) { if (seek_end < 100 + seek_start) {
return 0; return 0;
} }
@ -2438,7 +2536,7 @@ int whisper_full(
// main loop // main loop
int seek = seek_start; int seek = seek_start;
while (true) { while (true) {
int progress_cur = (100*seek)/whisper_n_len(ctx); const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
while (progress_cur >= progress_prev + progress_step) { while (progress_cur >= progress_prev + progress_step) {
progress_prev += progress_step; progress_prev += progress_step;
if (params.print_progress) { if (params.print_progress) {
@ -2446,7 +2544,7 @@ int whisper_full(
} }
} }
if (seek + 100 >= whisper_n_len(ctx)) { if (seek + 100 >= seek_end) {
break; break;
} }
@ -2527,7 +2625,7 @@ int whisper_full(
// end of text token // end of text token
if (token.id == whisper_token_eot(ctx)) { if (token.id == whisper_token_eot(ctx)) {
if (result_len == 0) { if (result_len == 0) {
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) { if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1; result_len = i + 1;
} else { } else {
// TODO: figure out how to resolve this // TODO: figure out how to resolve this
@ -2549,6 +2647,7 @@ int whisper_full(
} }
} }
// shrink down to result_len
tokens_cur.resize(result_len); tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) { for (const auto & r : tokens_cur) {
@ -2587,8 +2686,19 @@ int whisper_full(
for (int j = i0; j <= i; j++) { for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) { if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data); params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
} }
} }
text = ""; text = "";
@ -2617,8 +2727,19 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) { for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) { if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data); params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
} }
} }
} }
@ -2752,7 +2873,7 @@ int whisper_full_parallel(
// call the new_segment_callback for each segment // call the new_segment_callback for each segment
if (params.new_segment_callback) { if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data); params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
} }
} }
@ -2828,3 +2949,304 @@ const char * whisper_print_system_info() {
return s.c_str(); return s.c_str();
} }
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
static int64_t sample_to_timestamp(int i_sample) {
return (100*i_sample)/WHISPER_SAMPLE_RATE;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
// average the fabs of the signal
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
const int hw = n_samples_per_half_window;
std::vector<float> result(n_samples);
for (int i = 0; i < n_samples; i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < n_samples) {
sum += fabs(signal[i + j]);
}
}
result[i] = sum/(2*hw + 1);
}
return result;
}
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = ctx->result_all[i_segment];
auto & tokens = segment.tokens;
const int n_samples = ctx->energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {
return;
}
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
auto & t_beg = ctx->t_beg;
auto & t_last = ctx->t_last;
auto & tid_last = ctx->tid_last;
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += ctx->energy[k];
}
const float thold = 0.5*sum/ns;
{
int k = s0;
if (ctx->energy[k] > thold && j > 0) {
while (k > 0 && ctx->energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (ctx->energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
{
int k = s1;
if (ctx->energy[k] > thold) {
while (k < n_samples - 1 && ctx->energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (ctx->energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}

@ -68,14 +68,21 @@ extern "C" {
typedef int whisper_token; typedef int whisper_token;
struct whisper_token_data { typedef struct whisper_token_data {
whisper_token id; // token id whisper_token id; // token id
whisper_token tid; // forced timestamp token id whisper_token tid; // forced timestamp token id
float p; // probability of the token float p; // probability of the token
float pt; // probability of the timestamp token float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens float ptsum; // sum of probabilities of all timestamp tokens
};
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file. // Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure. // Returns NULL on failure.
@ -129,7 +136,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function. // You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability // whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token // whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found // Return the id of the specified language, returns -1 if not found
@ -172,14 +179,15 @@ extern "C" {
// Text segment callback // Text segment callback
// Called on every newly generated text segment // Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
struct whisper_full_params { struct whisper_full_params {
enum whisper_sampling_strategy strategy; enum whisper_sampling_strategy strategy;
int n_threads; int n_threads;
int n_max_text_ctx; int n_max_text_ctx;
int offset_ms; int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate; bool translate;
bool no_context; bool no_context;
@ -188,6 +196,12 @@ extern "C" {
bool print_realtime; bool print_realtime;
bool print_timestamps; bool print_timestamps;
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
const char * language; const char * language;
struct { struct {
@ -244,7 +258,7 @@ extern "C" {
// Get token data for the specified token in the specified segment. // Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc. // This contains probabilities, timestamps, etc.
WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment. // Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);

@ -1,5 +1,174 @@
#pragma once #pragma once
//
// GGML Tensor Library
//
// This documentation is still a work in progress.
// If you wish some specific topics to be covered, feel free to drop a comment:
//
// https://github.com/ggerganov/whisper.cpp/issues/40
//
// ## Overview
//
// This library implements:
//
// - a set of tensor operations
// - automatic differentiation
// - basic optimization algorithms
//
// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
// but is not limited to, the following:
//
// - linear regression
// - support vector machines
// - neural networks
//
// The library allows the user to define a certain function using the available tensor operations. This function
// definition is represented internally via a computation graph. Each tensor operation in the function definition
// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
// using one of the available optimization algorithms.
//
// For example, here we define the function: f(x) = a*x^2 + b
//
// {
// struct ggml_init_params params = {
// .mem_size = 16*1024*1024,
// .mem_buffer = NULL,
// };
//
// // memory allocation happens here
// struct ggml_context * ctx = ggml_init(params);
//
// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
//
// ggml_set_param(ctx, x); // x is an input variable
//
// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
// struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
//
// ...
// }
//
// Notice that the function definition above does not involve any actual computation. The computation is performed only
// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
//
// {
// ...
//
// struct ggml_cgraph gf = ggml_build_forward(f);
//
// // set the input variable and parameter values
// ggml_set_f32(x, 2.0f);
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
// ggml_graph_compute(ctx0, &gf);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
// ...
// }
//
// The actual computation is performed in the ggml_graph_compute() function.
//
// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
// actually needed.
//
// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
// differentiation and optimization algorithms.
//
// The described approach allows to define the function graph once and then compute its forward or backward graphs
// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
// the user can avoid the memory allocation overhead at runtime.
//
// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
// citizens, but in theory the library can be extended to support FP8 and integer data types.
//
// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
// clear that the library needs to support more complex operations. The way to support these operations is not clear
// yet, but a few examples are demonstrated in the following operations:
//
// - ggml_permute()
// - ggml_conv_1d_1s()
// - ggml_conv_1d_2s()
//
// For each tensor operator, the library implements a forward and backward computation function. The forward function
// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
// calculus class, or watch the following video:
//
// What is Automatic Differentiation?
// https://www.youtube.com/watch?v=wG_nF1awSSY
//
//
// ## Tensor data (struct ggml_tensor)
//
// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
//
// {
// struct ggml_tensor * c = ggml_add(ctx, a, b);
//
// assert(c->src[0] == a);
// assert(c->src[1] == b);
// }
//
// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
// contiguous in memory.
//
// The data of the tensor is accessed via the "data" pointer. For example:
//
// {
// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
//
// // a[1, 2] = 1.0f;
// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
//
// // a[2, 0] = 2.0f;
// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
//
// ...
// }
//
// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
//
// ## The matrix multiplication operator (ggml_mul_mat)
//
// TODO
//
//
// ## Multi-threading
//
// TODO
//
//
// ## Overview of ggml.c
//
// TODO
//
//
// ## SIMD optimizations
//
// TODO
//
//
// ## Debugging ggml
//
// TODO
//
//
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
@ -21,6 +190,7 @@ typedef __fp16 ggml_fp16_t;
typedef uint16_t ggml_fp16_t; typedef uint16_t ggml_fp16_t;
#endif #endif
// convert FP16 <-> FP32
float ggml_fp16_to_fp32(ggml_fp16_t x); float ggml_fp16_to_fp32(ggml_fp16_t x);
ggml_fp16_t ggml_fp32_to_fp16(float x); ggml_fp16_t ggml_fp32_to_fp16(float x);
@ -36,6 +206,7 @@ enum ggml_type {
GGML_TYPE_COUNT, GGML_TYPE_COUNT,
}; };
// available tensor operations:
enum ggml_op { enum ggml_op {
GGML_OP_NONE = 0, GGML_OP_NONE = 0,
@ -136,7 +307,7 @@ struct ggml_init_params {
void * mem_buffer; // if NULL, memory will be allocated internally void * mem_buffer; // if NULL, memory will be allocated internally
}; };
void ggml_time_init(void); void ggml_time_init(void); // call this once at the beginning of the program
int64_t ggml_time_ms(void); int64_t ggml_time_ms(void);
int64_t ggml_time_us(void); int64_t ggml_time_us(void);
int64_t ggml_cycles(void); int64_t ggml_cycles(void);

@ -14,7 +14,7 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#if defined _MSC_VER #if defined _MSC_VER || defined(__MINGW32__)
#include <Windows.h> #include <Windows.h>
typedef volatile LONG atomic_int; typedef volatile LONG atomic_int;
@ -37,8 +37,14 @@ typedef HANDLE pthread_t;
typedef DWORD thread_ret_t; typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
out = CreateThread(NULL, 0, func, arg, 0, NULL); HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL);
return out != NULL; if (handle == NULL)
{
return EAGAIN;
}
*out = handle;
return 0;
} }
static int pthread_join(pthread_t thread, void* unused) { static int pthread_join(pthread_t thread, void* unused) {
@ -198,7 +204,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
// timing // timing
// //
#if defined(_MSC_VER) #if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq; static int64_t timer_freq;
void ggml_time_init(void) { void ggml_time_init(void) {
LARGE_INTEGER frequency; LARGE_INTEGER frequency;
@ -3150,7 +3156,10 @@ void ggml_compute_forward_add_f32(
GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) { if (nb10 == sizeof(float)) {
for (int j = ith; j < n; j += nth) { const int j0 = (n/nth)*ith;
const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
for (int j = j0; j < j1; j++) {
ggml_vec_add_f32(nc, ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1), (float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01), (float *) ((char *) src0->data + j*nb01),
@ -6857,7 +6866,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break; } break;
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
node->n_tasks = 1; node->n_tasks = n_threads;
} break; } break;
case GGML_OP_SUB: case GGML_OP_SUB:
case GGML_OP_MUL: case GGML_OP_MUL:
@ -8089,7 +8098,7 @@ int ggml_cpu_has_avx512(void) {
} }
int ggml_cpu_has_neon(void) { int ggml_cpu_has_neon(void) {
#if defined(__ARM_NEON__) #if defined(__ARM_NEON)
return 1; return 1;
#else #else
return 0; return 0;

Loading…
Cancel
Save