ggml : sync with latest code from whisper.cpp

pull/12/head
Georgi Gerganov 2 years ago
parent 90ee5c6358
commit f56828ed78
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf); return std::string(buf);
} }
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
// helper function to replace substrings // helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) { void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) { for (size_t pos = 0; ; pos += replace.length()) {
@ -48,7 +52,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters // command-line parameters
struct whisper_params { struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1; int32_t n_processors = 1;
int32_t offset_t_ms = 0; int32_t offset_t_ms = 0;
@ -59,15 +62,16 @@ struct whisper_params {
float word_thold = 0.01f; float word_thold = 0.01f;
bool verbose = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool output_txt = false; bool diarize = false;
bool output_vtt = false; bool output_txt = false;
bool output_srt = false; bool output_vtt = false;
bool output_wts = false; bool output_srt = false;
bool print_special_tokens = false; bool output_wts = false;
bool print_colors = false; bool print_special = false;
bool no_timestamps = false; bool print_colors = false;
bool no_timestamps = false;
std::string language = "en"; std::string language = "en";
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
@ -86,57 +90,32 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
continue; continue;
} }
if (arg == "-s" || arg == "--seed") { if (arg == "-h" || arg == "--help") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--processors") {
params.n_processors = std::stoi(argv[++i]);
} else if (arg == "-ot" || arg == "--offset-t") {
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-d" || arg == "--duration") {
params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
} else if (arg == "-otxt" || arg == "--output-txt") {
params.output_txt = true;
} else if (arg == "-ovtt" || arg == "--output-vtt") {
params.output_vtt = true;
} else if (arg == "-osrt" || arg == "--output-srt") {
params.output_srt = true;
} else if (arg == "-owts" || arg == "--output-words") {
params.output_wts = true;
} else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true;
} else if (arg == "-pc" || arg == "--print_colors") {
params.print_colors = true;
} else if (arg == "-nt" || arg == "--no_timestamps") {
params.no_timestamps = true;
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_inp.push_back(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
exit(0); exit(0);
} else { }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
exit(0); exit(0);
@ -151,33 +130,40 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n"); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
};
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
@ -191,7 +177,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
if (params.no_timestamps) { if (params.no_timestamps) {
if (params.print_colors) { if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) { if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j); const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) { if (id >= whisper_token_eot(ctx)) {
continue; continue;
@ -214,10 +200,37 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) { if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) { if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j); const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) { if (id >= whisper_token_eot(ctx)) {
continue; continue;
@ -229,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
printf("\n"); printf("\n");
} else { } else {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
} }
} }
} }
@ -263,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return 9; return false;
} }
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -386,9 +399,9 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
ncnt += txt.size(); ncnt += txt.size();
} }
::replace_all(txt_bg, "'", ""); ::replace_all(txt_bg, "'", "\u2019");
::replace_all(txt_bg, "\"", "\\\""); ::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", ""); ::replace_all(txt_fg, "'", "\u2019");
::replace_all(txt_fg, "\"", "\\\""); ::replace_all(txt_fg, "\"", "\\\"");
} }
@ -428,16 +441,18 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
if (params.seed < 0) {
params.seed = time(NULL);
}
if (params.fname_inp.empty()) { if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n"); fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
return 2; return 2;
} }
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init // whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str()); struct whisper_context * ctx = whisper_init(params.model.c_str());
@ -450,53 +465,60 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input // WAV input
std::vector<float> pcmf32;
{ {
drwav wav; drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") { if (fname_inp == "-") {
std::vector<uint8_t> wav_data;
{ {
uint8_t buf[1024]; uint8_t buf[1024];
while (true) while (true)
{ {
const size_t n = fread(buf, 1, sizeof(buf), stdin); const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) if (n == 0) {
{
break; break;
} }
wav_data.insert(wav_data.end(), buf, buf + n); wav_data.insert(wav_data.end(), buf, buf + n);
} }
} }
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
{
fprintf(stderr, "error: failed to open WAV file from stdin\n"); fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4; return 4;
} }
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
} }
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) { else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 4; return 5;
} }
if (wav.channels != 1 && wav.channels != 2) { if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 5; return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
} }
if (wav.sampleRate != WHISPER_SAMPLE_RATE) { if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 6; return 8;
} }
if (wav.bitsPerSample != 16) { if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 7; return 9;
} }
int n = wav.totalPCMFrameCount; const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16; std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels); pcm16.resize(n*wav.channels);
@ -514,6 +536,18 @@ int main(int argc, char ** argv) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
} }
} }
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (int i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -548,30 +582,47 @@ int main(int argc, char ** argv) {
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_progress = false; wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens; wparams.print_special = params.print_special;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms; wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms; wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold; wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment // this callback is called on each new segment
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback; wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params; wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
} }
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 8; return 10;
} }
} }

@ -424,6 +424,9 @@ struct whisper_context {
int64_t t_last; int64_t t_last;
whisper_token tid_last; whisper_token tid_last;
std::vector<float> energy; // PCM signal energy std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
}; };
// load the model from a ggml file // load the model from a ggml file
@ -515,15 +518,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
fprintf(stderr, "%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
} }
// load mel filters // load mel filters
@ -596,11 +590,21 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
} }
} }
{
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
}
// for the big tensors, we have the option to store the data in 16-bit floats // for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
size_t ctx_size = 0; size_t ctx_size = 0;
size_t ctx_mem_size = 0; size_t ctx_mem_size = 0;
@ -613,7 +617,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_audio_state = hparams.n_audio_state; const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer; const int n_audio_layer = hparams.n_audio_layer;
const int n_text_ctx = hparams.n_text_ctx; const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state; const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer; const int n_text_layer = hparams.n_text_layer;
@ -719,7 +723,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
} }
// create the ggml context // create the ggml context
@ -748,7 +752,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_audio_state = hparams.n_audio_state; const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer; const int n_audio_layer = hparams.n_audio_layer;
const int n_text_ctx = hparams.n_text_ctx; const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state; const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer; const int n_text_layer = hparams.n_text_layer;
@ -967,7 +971,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
{ {
const int n_audio_ctx = hparams.n_audio_ctx; const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
@ -980,7 +984,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0); fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
} }
// load weights // load weights
@ -1039,12 +1043,12 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor)); fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor); total_size += ggml_nbytes(tensor);
model.n_loaded++; model.n_loaded++;
} }
fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
if (model.n_loaded == 0) { if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
@ -1076,13 +1080,11 @@ static bool whisper_encode(
const auto & mel_inp = wctx.mel; const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int n_ctx = hparams.n_audio_ctx; const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state; const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head; const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer; const int n_layer = hparams.n_audio_layer;
const int N = n_ctx;
const int n_mels = hparams.n_mels; const int n_mels = hparams.n_mels;
assert(mel_inp.n_mel == n_mels); assert(mel_inp.n_mel == n_mels);
@ -1132,7 +1134,30 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
} }
cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); // ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
//const int n_iter = 1500/n_ctx;
//iter = (iter + 1) % n_iter;
//if (iter == 0) {
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
//}
static int iter = 0;
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
// ===================================================================
// original:
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
struct ggml_tensor * inpL = cur; struct ggml_tensor * inpL = cur;
@ -1198,14 +1223,14 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Qcur, Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * V = struct ggml_tensor * V =
@ -1213,9 +1238,9 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_reshape_3d(ctxL, ggml_reshape_3d(ctxL,
Vcur, Vcur,
n_state/n_head, n_head, N), n_state/n_head, n_head, n_ctx),
1, 2, 0, 3), 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head) ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
); );
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1224,14 +1249,14 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Qcur, Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
// K * Q // K * Q
@ -1249,7 +1274,7 @@ static bool whisper_encode(
// ggml_permute(ctxL, // ggml_permute(ctxL,
// ggml_cpy(ctxL, // ggml_cpy(ctxL,
// Vcur, // Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3); // 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1259,9 +1284,9 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_reshape_3d(ctxL, ggml_reshape_3d(ctxL,
Vcur, Vcur,
n_state/n_head, n_head, N), n_state/n_head, n_head, n_ctx),
0, 2, 1, 3), 0, 2, 1, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
); );
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1271,7 +1296,7 @@ static bool whisper_encode(
cur = ggml_cpy(ctxL, cur = ggml_cpy(ctxL,
KQV_merged, KQV_merged,
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
} }
// projection // projection
@ -1425,6 +1450,8 @@ static bool whisper_encode(
Vcross), Vcross),
Vcross); Vcross);
//struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
@ -1474,7 +1501,7 @@ static bool whisper_decode(
const int n_layer = hparams.n_text_layer; const int n_layer = hparams.n_text_layer;
const int N = n_tokens; const int N = n_tokens;
const int M = hparams.n_audio_ctx; const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
struct ggml_init_params params = { struct ggml_init_params params = {
.mem_size = wctx.buf_compute.size(), .mem_size = wctx.buf_compute.size(),
@ -1819,7 +1846,9 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token // the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best( static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab, const whisper_vocab & vocab,
const float * probs) { const float * probs,
bool force_timestamp,
bool is_initial) {
whisper_token_data result = { whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
}; };
@ -1842,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
max_tx = std::max(max_tx, probs_id[i].first); max_tx = std::max(max_tx, probs_id[i].first);
} }
for (int i = vocab.token_beg; i < n_logits; i++) { const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
// the initial timestamp cannot be larger than 100
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial) {
for (int i = i0; i < n_logits; ++ i) {
probs_id[i].first = -INFINITY;
}
}
for (int i = vocab.token_beg; i < i1; i++) {
sum_ts += probs_id[i].first; sum_ts += probs_id[i].first;
if (probs_id[i].first > max_ts) { if (probs_id[i].first > max_ts) {
max_ts = probs_id[i].first; max_ts = probs_id[i].first;
@ -1852,7 +1892,7 @@ static whisper_token_data whisper_sample_best(
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token // timestamp token
if (sum_ts > max_tx) { if (sum_ts > max_tx || force_timestamp) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) { for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY; probs_id[i].first = -INFINITY;
@ -1894,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
return result; return result;
} }
// samples only from the timestamps tokens
static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
for (int i = vocab.token_beg + 1; i < n_logits; i++) {
probs_id.push_back(std::make_pair(probs[i], i));
}
const int top_k = 10;
// find the top K tokens
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
[](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
return a.first > b.first;
});
probs_id.resize(top_k);
//printf("\n");
//for (int i = 0; i < (int) probs_id.size(); i++) {
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
return probs_id[0].second;
}
// 500 -> 00:05.000 // 500 -> 00:05.000
// 6000 -> 01:00.000 // 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) { static std::string to_timestamp(int64_t t, bool comma = false) {
@ -2031,6 +2038,7 @@ static bool log_mel_spectrogram(
const int n_mel, const int n_mel,
const int n_threads, const int n_threads,
const whisper_filters & filters, const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) { whisper_mel & mel) {
// Hanning window // Hanning window
@ -2044,7 +2052,7 @@ static bool log_mel_spectrogram(
mel.n_len = (n_samples)/fft_step; mel.n_len = (n_samples)/fft_step;
mel.data.resize(mel.n_mel*mel.n_len); mel.data.resize(mel.n_mel*mel.n_len);
const int n_fft = 1 + fft_size/2; const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
@ -2091,6 +2099,13 @@ static bool log_mel_spectrogram(
//} //}
} }
if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
}
}
// mel spectrogram // mel spectrogram
for (int j = 0; j < mel.n_mel; j++) { for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0; double sum = 0.0;
@ -2161,6 +2176,12 @@ struct whisper_context * whisper_init(const char * path_model) {
void whisper_free(struct whisper_context * ctx) { void whisper_free(struct whisper_context * ctx) {
if (ctx) { if (ctx) {
if (ctx->model.ctx) {
ggml_free(ctx->model.ctx);
}
if (ctx->model.ctx_mem) {
ggml_free(ctx->model.ctx_mem);
}
if (ctx->buf_model) { if (ctx->buf_model) {
delete ctx->buf_model; delete ctx->buf_model;
} }
@ -2171,7 +2192,21 @@ void whisper_free(struct whisper_context * ctx) {
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) { if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1; return -1;
} }
@ -2229,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
return res; return res;
} }
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) { struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2305,11 +2338,11 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
return ctx->vocab.token_beg; return ctx->vocab.token_beg;
} }
whisper_token whisper_token_translate() { whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate; return whisper_vocab::token_translate;
} }
whisper_token whisper_token_transcribe() { whisper_token whisper_token_transcribe(void) {
return whisper_vocab::token_transcribe; return whisper_vocab::token_transcribe;
} }
@ -2325,6 +2358,27 @@ void whisper_print_timings(struct whisper_context * ctx) {
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
} }
void whisper_reset_timings(struct whisper_context * ctx) {
ctx->t_sample_us = 0;
ctx->t_encode_us = 0;
ctx->t_decode_us = 0;
}
const char * whisper_print_system_info(void) {
static std::string s;
s = "";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
return s.c_str();
}
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
@ -2334,77 +2388,99 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
case WHISPER_SAMPLING_GREEDY: case WHISPER_SAMPLING_GREEDY:
{ {
result = { result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY, /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.translate =*/ false,
/*.n_max_text_ctx =*/ 16384, /*.no_context =*/ false,
/*.offset_ms =*/ 0, /*.single_segment =*/ false,
/*.duration_ms =*/ 0, /*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.translate =*/ false, /*.token_timestamps =*/ false,
/*.no_context =*/ false, /*.thold_pt =*/ 0.01f,
/*.print_special_tokens =*/ false, /*.thold_ptsum =*/ 0.01f,
/*.print_progress =*/ true, /*.max_len =*/ 0,
/*.print_realtime =*/ false, /*.max_tokens =*/ 0,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false, /*.speed_up =*/ false,
/*.thold_pt =*/ 0.01f, /*.audio_ctx =*/ 0,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en", /*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.greedy =*/ { /*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ 0, /*.n_past =*/ 0,
}, },
/*.beam_search =*/ { /*.beam_search =*/ {
/*.n_past =*/ -1, /*.n_past =*/ -1,
/*.beam_width =*/ -1, /*.beam_width =*/ -1,
/*.n_best =*/ -1, /*.n_best =*/ -1,
}, },
/*.new_segment_callback =*/ nullptr, /*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
}; };
} break; } break;
case WHISPER_SAMPLING_BEAM_SEARCH: case WHISPER_SAMPLING_BEAM_SEARCH:
{ {
result = { result = {
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384, /*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0, /*.offset_ms =*/ 0,
/*.duration_ms =*/ 0, /*.duration_ms =*/ 0,
/*.translate =*/ false, /*.translate =*/ false,
/*.no_context =*/ false, /*.no_context =*/ false,
/*.print_special_tokens =*/ false, /*.single_segment =*/ false,
/*.print_progress =*/ true, /*.print_special =*/ false,
/*.print_realtime =*/ false, /*.print_progress =*/ true,
/*.print_timestamps =*/ true, /*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false, /*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f, /*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f, /*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0, /*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.language =*/ "en", /*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.greedy =*/ { /*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ -1, /*.n_past =*/ -1,
}, },
/*.beam_search =*/ { /*.beam_search =*/ {
/*.n_past =*/ 0, /*.n_past =*/ 0,
/*.beam_width =*/ 10, /*.beam_width =*/ 10,
/*.n_best =*/ 5, /*.n_best =*/ 5,
}, },
/*.new_segment_callback =*/ nullptr, /*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
}; };
} break; } break;
} }
@ -2485,9 +2561,16 @@ int whisper_full(
result_all.clear(); result_all.clear();
// compute log mel spectrogram // compute log mel spectrogram
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { if (params.speed_up) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
return -1; fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
} }
if (params.token_timestamps) { if (params.token_timestamps) {
@ -2513,6 +2596,18 @@ int whisper_full(
prompt_past.clear(); prompt_past.clear();
} }
// prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// parse tokens from the pointer
for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]);
}
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed // these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) { if (whisper_is_multilingual(ctx)) {
@ -2548,6 +2643,13 @@ int whisper_full(
break; break;
} }
if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
break;
}
}
// encode audio features starting at offset seek // encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) { if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__); fprintf(stderr, "%s: failed to encode\n", __func__);
@ -2570,7 +2672,6 @@ int whisper_full(
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
bool done = false;
int seek_delta = 100*WHISPER_CHUNK_SIZE; int seek_delta = 100*WHISPER_CHUNK_SIZE;
// print the prompt // print the prompt
@ -2584,7 +2685,9 @@ int whisper_full(
int result_len = 0; int result_len = 0;
tokens_cur.clear(); tokens_cur.clear();
for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) { bool failed = false;
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);
return 8; return 8;
@ -2601,15 +2704,19 @@ int whisper_full(
// feel free to experiment! // feel free to experiment!
// //
{ {
auto token = whisper_sample_best(ctx); const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
if (i == 0) {
token.tid = whisper_token_beg(ctx);
}
// timestamp token - update sliding window // timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) { if (token.id > whisper_token_beg(ctx)) {
seek_delta = 2*(token.id - whisper_token_beg(ctx)); const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time
if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
seek_delta > seek_delta_new && result_len < i) {
break;
}
seek_delta = seek_delta_new;
result_len = i + 1; result_len = i + 1;
} }
@ -2619,19 +2726,25 @@ int whisper_full(
//{ //{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str()); // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
//} //}
// end of text token // end of text token
if (token.id == whisper_token_eot(ctx)) { if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
if (result_len == 0) { if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) { if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1; result_len = i + 1;
} else { } else {
// TODO: figure out how to resolve this failed = true;
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); break;
} }
} }
if (params.single_segment) {
result_len = i + 1;
seek_delta = 100*WHISPER_CHUNK_SIZE;
}
break; break;
} }
@ -2642,11 +2755,21 @@ int whisper_full(
} }
} }
if (done) { // sometimes, the decoding can get stuck in a repetition loop
// this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
// the sliding window by 1 second
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
failed = true;
break; break;
} }
} }
if (failed) {
fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
seek += 100;
continue;
}
// shrink down to result_len // shrink down to result_len
tokens_cur.resize(result_len); tokens_cur.resize(result_len);
@ -2666,23 +2789,26 @@ int whisper_full(
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else { } else {
text += whisper_token_to_str(ctx, tokens_cur[i].id); text += whisper_token_to_str(ctx, tokens_cur[i].id);
} }
if (tokens_cur[i].id > whisper_token_beg(ctx)) { if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) { if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else { } else {
printf("%s", text.c_str()); printf("%s", text.c_str());
fflush(stdout); fflush(stdout);
} }
} }
result_all.push_back({ t0, t1, text, {} }); result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) { for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
@ -2714,16 +2840,19 @@ int whisper_full(
if (!text.empty()) { if (!text.empty()) {
const auto t1 = seek + seek_delta; const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else { } else {
printf("%s", text.c_str()); printf("%s", text.c_str());
fflush(stdout); fflush(stdout);
} }
} }
result_all.push_back({ t0, t1, text, {} }); result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) { for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
@ -2755,7 +2884,7 @@ int whisper_full_parallel(
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples, int n_samples,
const int n_processors) { int n_processors) {
if (n_processors == 1) { if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples); return whisper_full(ctx, params, samples, n_samples);
} }
@ -2805,7 +2934,7 @@ int whisper_full_parallel(
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
{ {
const int n_audio_ctx = hparams.n_audio_ctx; const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
@ -2813,10 +2942,6 @@ int whisper_full_parallel(
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
} }
const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
} }
} }
@ -2936,20 +3061,6 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
return ctx->result_all[i_segment].tokens[i_token].p; return ctx->result_all[i_segment].tokens[i_token].p;
} }
const char * whisper_print_system_info() {
static std::string s;
s = "";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
return s.c_str();
}
// ================================================================================================= // =================================================================================================
// //
@ -3036,9 +3147,6 @@ static void whisper_exp_compute_token_level_timestamps(
const int64_t t0 = segment.t0; const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1; const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size(); const int n = tokens.size();
if (n == 0) { if (n == 0) {

@ -72,16 +72,16 @@ extern "C" {
whisper_token id; // token id whisper_token id; // token id
whisper_token tid; // forced timestamp token id whisper_token tid; // forced timestamp token id
float p; // probability of the token float p; // probability of the token
float pt; // probability of the timestamp token float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data // token-level timestamp data
// do not use if you haven't computed token-level timestamps // do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token int64_t t0; // start time of the token
int64_t t1; // end time of the token int64_t t1; // end time of the token
float vlen; // voice length of the token float vlen; // voice length of the token
} whisper_token_data; } whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file. // Allocates all memory needed for the model and loads the model from the given file.
@ -96,9 +96,9 @@ extern "C" {
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_pcm_to_mel( WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx, struct whisper_context * ctx,
const float * samples, const float * samples,
int n_samples, int n_samples,
int n_threads); int n_threads);
// This can be used to set a custom log mel spectrogram inside the provided whisper context. // This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
@ -106,9 +106,9 @@ extern "C" {
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_set_mel( WHISPER_API int whisper_set_mel(
struct whisper_context * ctx, struct whisper_context * ctx,
const float * data, const float * data,
int n_len, int n_len,
int n_mel); int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
@ -116,8 +116,8 @@ extern "C" {
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_encode( WHISPER_API int whisper_encode(
struct whisper_context * ctx, struct whisper_context * ctx,
int offset, int offset,
int n_threads); int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token. // Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first. // Make sure to call whisper_encode() first.
@ -126,10 +126,10 @@ extern "C" {
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_decode( WHISPER_API int whisper_decode(
struct whisper_context * ctx, struct whisper_context * ctx,
const whisper_token * tokens, const whisper_token * tokens,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads); int n_threads);
// Token sampling methods. // Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode(). // These are provided for convenience and can be used after each call to whisper_decode().
@ -137,7 +137,7 @@ extern "C" {
// whisper_sample_best() returns the token with the highest probability // whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token // whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Return the id of the specified language, returns -1 if not found // Return the id of the specified language, returns -1 if not found
WHISPER_API int whisper_lang_id(const char * lang); WHISPER_API int whisper_lang_id(const char * lang);
@ -162,11 +162,15 @@ extern "C" {
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
// Task tokens // Task tokens
WHISPER_API whisper_token whisper_token_translate (); WHISPER_API whisper_token whisper_token_translate (void);
WHISPER_API whisper_token whisper_token_transcribe(); WHISPER_API whisper_token whisper_token_transcribe(void);
// Performance information // Performance information
WHISPER_API void whisper_print_timings(struct whisper_context * ctx); WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
// Print system information
WHISPER_API const char * whisper_print_system_info(void);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
@ -181,17 +185,26 @@ extern "C" {
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
struct whisper_full_params { struct whisper_full_params {
enum whisper_sampling_strategy strategy; enum whisper_sampling_strategy strategy;
int n_threads; int n_threads;
int n_max_text_ctx; int n_max_text_ctx;
int offset_ms; // start offset in ms int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms int duration_ms; // audio duration to process in ms
bool translate; bool translate;
bool no_context; bool no_context;
bool print_special_tokens; bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress; bool print_progress;
bool print_realtime; bool print_realtime;
bool print_timestamps; bool print_timestamps;
@ -201,6 +214,16 @@ extern "C" {
float thold_pt; // timestamp token probability threshold (~0.01) float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters int max_len; // max segment length in characters
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
const char * language; const char * language;
@ -216,6 +239,9 @@ extern "C" {
whisper_new_segment_callback new_segment_callback; whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data; void * new_segment_callback_user_data;
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
}; };
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -223,20 +249,20 @@ extern "C" {
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text. // Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full( WHISPER_API int whisper_full(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples); int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full() // Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel( WHISPER_API int whisper_full_parallel(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples, int n_samples,
const int n_processors); int n_processors);
// Number of generated text segments. // Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph. // A segment can be a few words, a sentence, or even a paragraph.
@ -263,9 +289,6 @@ extern "C" {
// Get the probability of the specified token in the specified segment. // Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
// Print system information
WHISPER_API const char * whisper_print_system_info();
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
// system info // system info
// //
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void); int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void); int ggml_cpu_has_avx512(void);
int ggml_cpu_has_neon(void); int ggml_cpu_has_neon(void);

@ -15,7 +15,14 @@
#include <stdio.h> #include <stdio.h>
#if defined _MSC_VER || defined(__MINGW32__) #if defined _MSC_VER || defined(__MINGW32__)
#if !defined(__MINGW32__)
#include <Windows.h> #include <Windows.h>
#else
// ref: https://github.com/ggerganov/whisper.cpp/issues/168
#include <windows.h>
#include <errno.h>
#endif
typedef volatile LONG atomic_int; typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool; typedef atomic_int atomic_bool;
@ -37,7 +44,7 @@ typedef HANDLE pthread_t;
typedef DWORD thread_ret_t; typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL); HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL) if (handle == NULL)
{ {
return EAGAIN; return EAGAIN;
@ -372,6 +379,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
sumf = _mm_cvtss_f32(r1); sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps();
__m256 sum1 = _mm256_setzero_ps();
__m256 sum2 = _mm256_setzero_ps();
__m256 sum3 = _mm256_setzero_ps();
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_loadu_ps(x + i + 0);
x1 = _mm256_loadu_ps(x + i + 8);
x2 = _mm256_loadu_ps(x + i + 16);
x3 = _mm256_loadu_ps(x + i + 24);
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}
sum0 = _mm256_add_ps(sum0, sum1);
sum2 = _mm256_add_ps(sum2, sum3);
sum0 = _mm256_add_ps(sum0, sum2);
const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
sumf = _mm_cvtss_f32(r1);
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i]; sumf += x[i]*y[i];
@ -569,6 +619,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
sumf = _mm_cvtss_f32(r1); sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
//GGML_ASSERT(false);
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps();
__m256 sum1 = _mm256_setzero_ps();
__m256 sum2 = _mm256_setzero_ps();
__m256 sum3 = _mm256_setzero_ps();
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}
const __m256 sum01 = _mm256_add_ps(sum0, sum1);
const __m256 sum23 = _mm256_add_ps(sum2, sum3);
const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
sumf = _mm_cvtss_f32(r1);
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
//GGML_ASSERT(false); //GGML_ASSERT(false);
@ -698,6 +792,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
_mm256_storeu_ps(y + i + 24, y3); _mm256_storeu_ps(y + i + 24, y3);
} }
// leftovers
for (int i = n32; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v4 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_loadu_ps(x + i + 0);
x1 = _mm256_loadu_ps(x + i + 8);
x2 = _mm256_loadu_ps(x + i + 16);
x3 = _mm256_loadu_ps(x + i + 24);
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 16, y2);
_mm256_storeu_ps(y + i + 24, y3);
}
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
y[i] += x[i]*v; y[i] += x[i]*v;
@ -859,6 +988,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
_mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
} }
// leftovers
for (int i = n32; i < n; ++i) {
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v8 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
_mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
_mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
_mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
_mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
}
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
GGML_ASSERT(false); GGML_ASSERT(false);
@ -8081,6 +8246,14 @@ enum ggml_opt_result ggml_opt(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
int ggml_cpu_has_avx(void) {
#if defined(__AVX__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_avx2(void) { int ggml_cpu_has_avx2(void) {
#if defined(__AVX2__) #if defined(__AVX2__)
return 1; return 1;

Loading…
Cancel
Save