ggml : sync with latest code from whisper.cpp

pull/12/head
Georgi Gerganov 2 years ago
parent 90ee5c6358
commit f56828ed78
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
@ -48,7 +52,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
@ -59,13 +62,14 @@ struct whisper_params {
float word_thold = 0.01f;
bool verbose = false;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool print_special_tokens = false;
bool print_special = false;
bool print_colors = false;
bool no_timestamps = false;
@ -86,57 +90,32 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
continue;
}
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--processors") {
params.n_processors = std::stoi(argv[++i]);
} else if (arg == "-ot" || arg == "--offset-t") {
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-d" || arg == "--duration") {
params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
} else if (arg == "-otxt" || arg == "--output-txt") {
params.output_txt = true;
} else if (arg == "-ovtt" || arg == "--output-vtt") {
params.output_vtt = true;
} else if (arg == "-osrt" || arg == "--output-srt") {
params.output_srt = true;
} else if (arg == "-owts" || arg == "--output-words") {
params.output_wts = true;
} else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true;
} else if (arg == "-pc" || arg == "--print_colors") {
params.print_colors = true;
} else if (arg == "-nt" || arg == "--no_timestamps") {
params.no_timestamps = true;
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_inp.push_back(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
} else {
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -151,33 +130,40 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
};
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
@ -191,7 +177,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
@ -214,10 +200,37 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
@ -229,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
}
}
@ -263,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return 9;
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -386,9 +399,9 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "");
::replace_all(txt_bg, "'", "\u2019");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "");
::replace_all(txt_fg, "'", "\u2019");
::replace_all(txt_fg, "\"", "\\\"");
}
@ -428,16 +441,18 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.seed < 0) {
params.seed = time(NULL);
}
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params);
return 2;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
@ -450,53 +465,60 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input
std::vector<float> pcmf32;
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
std::vector<uint8_t> wav_data;
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0)
{
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false)
{
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 4;
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 5;
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 6;
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 7;
return 9;
}
int n = wav.totalPCMFrameCount;
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
@ -514,6 +536,18 @@ int main(int argc, char ** argv) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (int i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
// print system information
@ -551,7 +585,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
@ -563,15 +597,32 @@ int main(int argc, char ** argv) {
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 8;
return 10;
}
}

@ -424,6 +424,9 @@ struct whisper_context {
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
};
// load the model from a ggml file
@ -515,15 +518,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
fprintf(stderr, "%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
}
// load mel filters
@ -596,11 +590,21 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
}
}
{
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
}
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
size_t ctx_size = 0;
size_t ctx_mem_size = 0;
@ -719,7 +723,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
@ -980,7 +984,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
// load weights
@ -1039,12 +1043,12 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
model.n_loaded++;
}
fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
@ -1076,13 +1080,11 @@ static bool whisper_encode(
const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams;
const int n_ctx = hparams.n_audio_ctx;
const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
const int N = n_ctx;
const int n_mels = hparams.n_mels;
assert(mel_inp.n_mel == n_mels);
@ -1132,7 +1134,30 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur);
}
cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
//const int n_iter = 1500/n_ctx;
//iter = (iter + 1) % n_iter;
//if (iter == 0) {
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
//}
static int iter = 0;
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
// ===================================================================
// original:
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
struct ggml_tensor * inpL = cur;
@ -1198,14 +1223,14 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
@ -1213,9 +1238,9 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
n_state/n_head, n_head, N),
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1224,14 +1249,14 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
@ -1249,7 +1274,7 @@ static bool whisper_encode(
// ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1259,9 +1284,9 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
n_state/n_head, n_head, N),
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1271,7 +1296,7 @@ static bool whisper_encode(
cur = ggml_cpy(ctxL,
KQV_merged,
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
}
// projection
@ -1425,6 +1450,8 @@ static bool whisper_encode(
Vcross),
Vcross);
//struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
@ -1474,7 +1501,7 @@ static bool whisper_decode(
const int n_layer = hparams.n_text_layer;
const int N = n_tokens;
const int M = hparams.n_audio_ctx;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
struct ggml_init_params params = {
.mem_size = wctx.buf_compute.size(),
@ -1819,7 +1846,9 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
const float * probs,
bool force_timestamp,
bool is_initial) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
@ -1842,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
max_tx = std::max(max_tx, probs_id[i].first);
}
for (int i = vocab.token_beg; i < n_logits; i++) {
const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
// the initial timestamp cannot be larger than 100
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial) {
for (int i = i0; i < n_logits; ++ i) {
probs_id[i].first = -INFINITY;
}
}
for (int i = vocab.token_beg; i < i1; i++) {
sum_ts += probs_id[i].first;
if (probs_id[i].first > max_ts) {
max_ts = probs_id[i].first;
@ -1852,7 +1892,7 @@ static whisper_token_data whisper_sample_best(
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
if (sum_ts > max_tx) {
if (sum_ts > max_tx || force_timestamp) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY;
@ -1894,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
return result;
}
// samples only from the timestamps tokens
static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
for (int i = vocab.token_beg + 1; i < n_logits; i++) {
probs_id.push_back(std::make_pair(probs[i], i));
}
const int top_k = 10;
// find the top K tokens
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
[](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
return a.first > b.first;
});
probs_id.resize(top_k);
//printf("\n");
//for (int i = 0; i < (int) probs_id.size(); i++) {
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
return probs_id[0].second;
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
@ -2031,6 +2038,7 @@ static bool log_mel_spectrogram(
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) {
// Hanning window
@ -2044,7 +2052,7 @@ static bool log_mel_spectrogram(
mel.n_len = (n_samples)/fft_step;
mel.data.resize(mel.n_mel*mel.n_len);
const int n_fft = 1 + fft_size/2;
const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
@ -2091,6 +2099,13 @@ static bool log_mel_spectrogram(
//}
}
if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
}
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
@ -2161,6 +2176,12 @@ struct whisper_context * whisper_init(const char * path_model) {
void whisper_free(struct whisper_context * ctx) {
if (ctx) {
if (ctx->model.ctx) {
ggml_free(ctx->model.ctx);
}
if (ctx->model.ctx_mem) {
ggml_free(ctx->model.ctx_mem);
}
if (ctx->buf_model) {
delete ctx->buf_model;
}
@ -2171,7 +2192,21 @@ void whisper_free(struct whisper_context * ctx) {
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
ctx->t_mel_us = ggml_time_us() - t_start_us;
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
@ -2229,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
return res;
}
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2305,11 +2338,11 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
return ctx->vocab.token_beg;
}
whisper_token whisper_token_translate() {
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
}
whisper_token whisper_token_transcribe() {
whisper_token whisper_token_transcribe(void) {
return whisper_vocab::token_transcribe;
}
@ -2325,6 +2358,27 @@ void whisper_print_timings(struct whisper_context * ctx) {
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}
void whisper_reset_timings(struct whisper_context * ctx) {
ctx->t_sample_us = 0;
ctx->t_encode_us = 0;
ctx->t_decode_us = 0;
}
const char * whisper_print_system_info(void) {
static std::string s;
s = "";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
return s.c_str();
}
////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
@ -2343,7 +2397,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
@ -2352,6 +2407,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
@ -2367,6 +2429,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
@ -2381,7 +2446,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
@ -2390,6 +2456,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
@ -2405,6 +2478,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
}
@ -2485,10 +2561,17 @@ int whisper_full(
result_all.clear();
// compute log mel spectrogram
if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
}
if (params.token_timestamps) {
ctx->t_beg = 0;
@ -2513,6 +2596,18 @@ int whisper_full(
prompt_past.clear();
}
// prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// parse tokens from the pointer
for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]);
}
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
@ -2548,6 +2643,13 @@ int whisper_full(
break;
}
if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
break;
}
}
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
@ -2570,7 +2672,6 @@ int whisper_full(
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
bool done = false;
int seek_delta = 100*WHISPER_CHUNK_SIZE;
// print the prompt
@ -2584,7 +2685,9 @@ int whisper_full(
int result_len = 0;
tokens_cur.clear();
for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
bool failed = false;
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return 8;
@ -2601,15 +2704,19 @@ int whisper_full(
// feel free to experiment!
//
{
auto token = whisper_sample_best(ctx);
if (i == 0) {
token.tid = whisper_token_beg(ctx);
}
const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
seek_delta = 2*(token.id - whisper_token_beg(ctx));
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time
if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
seek_delta > seek_delta_new && result_len < i) {
break;
}
seek_delta = seek_delta_new;
result_len = i + 1;
}
@ -2619,19 +2726,25 @@ int whisper_full(
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
if (token.id == whisper_token_eot(ctx)) {
if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
failed = true;
break;
}
}
if (params.single_segment) {
result_len = i + 1;
seek_delta = 100*WHISPER_CHUNK_SIZE;
}
break;
}
@ -2642,11 +2755,21 @@ int whisper_full(
}
}
if (done) {
// sometimes, the decoding can get stuck in a repetition loop
// this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
// the sliding window by 1 second
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
failed = true;
break;
}
}
if (failed) {
fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
seek += 100;
continue;
}
// shrink down to result_len
tokens_cur.resize(result_len);
@ -2666,23 +2789,26 @@ int whisper_full(
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
if (tokens_cur[i].id > whisper_token_beg(ctx)) {
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -2714,16 +2840,19 @@ int whisper_full(
if (!text.empty()) {
const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -2755,7 +2884,7 @@ int whisper_full_parallel(
struct whisper_full_params params,
const float * samples,
int n_samples,
const int n_processors) {
int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
@ -2813,10 +2942,6 @@ int whisper_full_parallel(
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
}
}
@ -2936,20 +3061,6 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
return ctx->result_all[i_segment].tokens[i_token].p;
}
const char * whisper_print_system_info() {
static std::string s;
s = "";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
return s.c_str();
}
// =================================================================================================
//
@ -3036,9 +3147,6 @@ static void whisper_exp_compute_token_level_timestamps(
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {

@ -137,7 +137,7 @@ extern "C" {
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Return the id of the specified language, returns -1 if not found
WHISPER_API int whisper_lang_id(const char * lang);
@ -162,11 +162,15 @@ extern "C" {
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
// Task tokens
WHISPER_API whisper_token whisper_token_translate ();
WHISPER_API whisper_token whisper_token_transcribe();
WHISPER_API whisper_token whisper_token_translate (void);
WHISPER_API whisper_token whisper_token_transcribe(void);
// Performance information
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
// Print system information
WHISPER_API const char * whisper_print_system_info(void);
////////////////////////////////////////////////////////////////////////////
@ -181,6 +185,14 @@ extern "C" {
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
@ -191,7 +203,8 @@ extern "C" {
bool translate;
bool no_context;
bool print_special_tokens;
bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
@ -201,6 +214,16 @@ extern "C" {
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
const char * language;
@ -216,6 +239,9 @@ extern "C" {
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -236,7 +262,7 @@ extern "C" {
struct whisper_full_params params,
const float * samples,
int n_samples,
const int n_processors);
int n_processors);
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
@ -263,9 +289,6 @@ extern "C" {
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
// Print system information
WHISPER_API const char * whisper_print_system_info();
#ifdef __cplusplus
}
#endif

@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
// system info
//
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_neon(void);

@ -15,7 +15,14 @@
#include <stdio.h>
#if defined _MSC_VER || defined(__MINGW32__)
#if !defined(__MINGW32__)
#include <Windows.h>
#else
// ref: https://github.com/ggerganov/whisper.cpp/issues/168
#include <windows.h>
#include <errno.h>
#endif
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
@ -37,7 +44,7 @@ typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL);
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
{
return EAGAIN;
@ -372,6 +379,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps();
__m256 sum1 = _mm256_setzero_ps();
__m256 sum2 = _mm256_setzero_ps();
__m256 sum3 = _mm256_setzero_ps();
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_loadu_ps(x + i + 0);
x1 = _mm256_loadu_ps(x + i + 8);
x2 = _mm256_loadu_ps(x + i + 16);
x3 = _mm256_loadu_ps(x + i + 24);
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}
sum0 = _mm256_add_ps(sum0, sum1);
sum2 = _mm256_add_ps(sum2, sum3);
sum0 = _mm256_add_ps(sum0, sum2);
const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i];
@ -569,6 +619,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
//GGML_ASSERT(false);
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps();
__m256 sum1 = _mm256_setzero_ps();
__m256 sum2 = _mm256_setzero_ps();
__m256 sum3 = _mm256_setzero_ps();
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}
const __m256 sum01 = _mm256_add_ps(sum0, sum1);
const __m256 sum23 = _mm256_add_ps(sum2, sum3);
const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
sumf = _mm_cvtss_f32(r1);
// leftovers
for (int i = n32; i < n; ++i) {
//GGML_ASSERT(false);
@ -698,6 +792,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
_mm256_storeu_ps(y + i + 24, y3);
}
// leftovers
for (int i = n32; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v4 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
x0 = _mm256_loadu_ps(x + i + 0);
x1 = _mm256_loadu_ps(x + i + 8);
x2 = _mm256_loadu_ps(x + i + 16);
x3 = _mm256_loadu_ps(x + i + 24);
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 16, y2);
_mm256_storeu_ps(y + i + 24, y3);
}
// leftovers
for (int i = n32; i < n; ++i) {
y[i] += x[i]*v;
@ -859,6 +988,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
_mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
}
// leftovers
for (int i = n32; i < n; ++i) {
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v8 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
_mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
_mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
_mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
_mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
}
// leftovers
for (int i = n32; i < n; ++i) {
GGML_ASSERT(false);
@ -8081,6 +8246,14 @@ enum ggml_opt_result ggml_opt(
////////////////////////////////////////////////////////////////////////////////
int ggml_cpu_has_avx(void) {
#if defined(__AVX__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_avx2(void) {
#if defined(__AVX2__)
return 1;

Loading…
Cancel
Save