sync : latest changes from whisper.cpp

pull/12/head
Georgi Gerganov 2 years ago
parent 1dcbe86a0c
commit bd9f710a45
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -69,11 +69,14 @@ struct whisper_params {
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string prompt;
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
@ -109,12 +112,15 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -125,7 +131,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
@ -146,10 +152,13 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
@ -200,7 +209,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
std::string speaker;
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
@ -266,7 +275,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
fout << text;
fout << text << "\n";
}
return true;
@ -319,10 +328,36 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ')
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", "
<< 10 * t1 << ", \""
<< text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -371,7 +406,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul = "\\ \\ ";
{
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
@ -395,8 +429,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul += "\\ ";
}
}
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "\u2019");
@ -447,7 +479,7 @@ int main(int argc, char ** argv) {
return 2;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -462,6 +494,22 @@ int main(int argc, char ** argv) {
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
@ -486,14 +534,14 @@ int main(int argc, char ** argv) {
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
@ -528,11 +576,11 @@ int main(int argc, char ** argv) {
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
@ -543,7 +591,7 @@ int main(int argc, char ** argv) {
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
@ -577,13 +625,12 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
@ -599,6 +646,9 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
@ -613,7 +663,7 @@ int main(int argc, char ** argv) {
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
@ -653,6 +703,13 @@ int main(int argc, char ** argv) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_inp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

@ -14,6 +14,7 @@
#include <string>
#include <thread>
#include <vector>
#include <regex>
#define USE_FLASH_ATTN
//#define USE_FLASH_FF
@ -203,6 +204,10 @@ struct whisper_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
// used to avoid memory allocations during sampling
// TODO: move to whisper_context in the future
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
id token_eot = 50256;
id token_sot = 50257;
id token_prev = 50360;
@ -429,6 +434,12 @@ struct whisper_context {
int32_t exp_n_audio_ctx; // 0 - use default
};
template<typename T>
static void read_safe(std::ifstream& fin, T& dest)
{
fin.read((char*)& dest, sizeof(T));
}
// load the model from a ggml file
//
// file format:
@ -455,7 +466,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
read_safe(fin, magic);
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
@ -466,17 +477,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
read_safe(fin, hparams.n_vocab);
read_safe(fin, hparams.n_audio_ctx);
read_safe(fin, hparams.n_audio_state);
read_safe(fin, hparams.n_audio_head);
read_safe(fin, hparams.n_audio_layer);
read_safe(fin, hparams.n_text_ctx);
read_safe(fin, hparams.n_text_state);
read_safe(fin, hparams.n_text_head);
read_safe(fin, hparams.n_text_layer);
read_safe(fin, hparams.n_mels);
read_safe(fin, hparams.f16);
assert(hparams.n_text_state == hparams.n_audio_state);
@ -524,8 +535,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & filters = wctx.model.filters;
fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
read_safe(fin, filters.n_mel);
read_safe(fin, filters.n_fft);
filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@ -534,7 +545,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
read_safe(fin, n_vocab);
//if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@ -543,12 +554,23 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
//}
std::string word;
std::vector<char> tmp;
tmp.reserve(128);
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
read_safe(fin, len);
word.resize(len);
fin.read((char *) word.data(), len);
if (len > 0) {
tmp.resize(len);
fin.read(&tmp[0], tmp.size()); // read to buffer
word.assign(&tmp[0], tmp.size());
} else {
// seems like we have an empty-string token in multi-language models (i = 50256)
//fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
word = "";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
@ -588,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
vocab.id_to_token[i] = word;
}
}
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
vocab.probs_id.reserve(n_vocab);
}
{
@ -606,7 +633,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
size_t ctx_size = 0;
size_t ctx_mem_size = 0;
{
const auto & hparams = model.hparams;
@ -715,12 +741,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
}
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
@ -728,10 +748,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// create the ggml context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_model->size(),
.mem_buffer = wctx.buf_model->data(),
};
struct ggml_init_params params;
params.mem_size = wctx.buf_model->size();
params.mem_buffer = wctx.buf_model->data();
model.ctx = ggml_init(params);
if (!model.ctx) {
@ -938,10 +957,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
struct ggml_init_params params;
params.mem_size = wctx.buf_memory.size();
params.mem_buffer = wctx.buf_memory.data();
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
@ -998,9 +1016,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length;
int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
read_safe(fin, n_dims);
read_safe(fin, length);
read_safe(fin, ftype);
if (fin.eof()) {
break;
@ -1009,14 +1027,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
read_safe(fin, ne[i]);
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
std::string name;
std::vector<char> tmp(length); // create a buffer
fin.read(&tmp[0], tmp.size()); // read to buffer
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name.data()) == model.tensors.end()) {
if (model.tensors.find(name) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
@ -1088,10 +1108,9 @@ static bool whisper_encode(
const int n_mels = hparams.n_mels;
assert(mel_inp.n_mel == n_mels);
struct ggml_init_params params = {
.mem_size = wctx.buf_compute.size(),
.mem_buffer = wctx.buf_compute.data(),
};
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data();
struct ggml_context * ctx0 = ggml_init(params);
@ -1166,10 +1185,9 @@ static bool whisper_encode(
// create separate context for each layer to reduce memory usage
struct ggml_init_params paramsL = {
.mem_size = wctx.buf_compute_layer.size(),
.mem_buffer = wctx.buf_compute_layer.data(),
};
struct ggml_init_params paramsL;
paramsL.mem_size = wctx.buf_compute_layer.size();
paramsL.mem_buffer = wctx.buf_compute_layer.data();
struct ggml_context * ctxL = ggml_init(paramsL);
@ -1374,8 +1392,8 @@ static bool whisper_encode(
// input for next layer (inpO -> inpL)
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
inpL->op = GGML_OP_NONE;
inpL->src0 = NULL;
inpL->src1 = NULL;
inpL->src0 = nullptr;
inpL->src1 = nullptr;
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
@ -1428,8 +1446,8 @@ static bool whisper_encode(
// TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE;
cur->src0 = NULL;
cur->src1 = NULL;
cur->src0 = nullptr;
cur->src1 = nullptr;
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
@ -1503,10 +1521,9 @@ static bool whisper_decode(
const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
struct ggml_init_params params = {
.mem_size = wctx.buf_compute.size(),
.mem_buffer = wctx.buf_compute.data(),
};
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data();
struct ggml_context * ctx0 = ggml_init(params);
@ -1529,10 +1546,9 @@ static bool whisper_decode(
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_decoder[il];
struct ggml_init_params paramsL = {
.mem_size = wctx.buf_compute_layer.size(),
.mem_buffer = wctx.buf_compute_layer.data(),
};
struct ggml_init_params paramsL;
paramsL.mem_size = wctx.buf_compute_layer.size();
paramsL.mem_buffer = wctx.buf_compute_layer.data();
struct ggml_context * ctxL = ggml_init(paramsL);
struct ggml_cgraph gf = {};
@ -1788,8 +1804,8 @@ static bool whisper_decode(
// input for next layer (inpO -> inpL)
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
inpL->op = GGML_OP_NONE;
inpL->src0 = NULL;
inpL->src1 = NULL;
inpL->src0 = nullptr;
inpL->src1 = nullptr;
if (N > 1) {
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
@ -1845,7 +1861,7 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
whisper_vocab & vocab,
const float * probs,
bool force_timestamp,
bool is_initial) {
@ -1853,13 +1869,13 @@ static whisper_token_data whisper_sample_best(
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab.id_to_token.size();
const int n_logits = vocab.n_vocab;
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
auto & probs_id = vocab.probs_id;
probs_id.clear();
for (int i = 0; i < n_logits; i++) {
probs_id.push_back(std::make_pair(probs[i], i));
probs_id.emplace_back(probs[i], i);
}
{
@ -1997,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
std::vector<float> even;
std::vector<float> odd;
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
@ -2032,7 +2051,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
const int /*sample_rate*/,
const int fft_size,
const int fft_step,
const int n_mel,
@ -2151,6 +2170,71 @@ static bool log_mel_spectrogram(
return true;
}
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<whisper_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
//
// interface implementation
//
@ -2166,7 +2250,8 @@ struct whisper_context * whisper_init(const char * path_model) {
if (!whisper_model_load(path_model, *ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
return NULL;
delete ctx;
return nullptr;
}
ctx->t_load_us = ggml_time_us() - t_start_us;
@ -2281,8 +2366,38 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
return res;
}
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
const auto res = tokenize(ctx->vocab, text);
if (n_max_tokens < (int) res.size()) {
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
return -1;
}
for (int i = 0; i < (int) res.size(); i++) {
tokens[i] = res[i];
}
return res.size();
}
int whisper_lang_max_id() {
auto max_id = 0;
for (const auto & kv : g_lang) {
max_id = std::max(max_id, kv.second.first);
}
return max_id;
}
int whisper_lang_id(const char * lang) {
if (!g_lang.count(lang)) {
for (const auto & kv : g_lang) {
if (kv.second.second == lang) {
return kv.second.first;
}
}
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
return -1;
}
@ -2290,6 +2405,86 @@ int whisper_lang_id(const char * lang) {
return g_lang.at(lang).first;
}
const char * whisper_lang_str(int id) {
for (const auto & kv : g_lang) {
if (kv.second.first == id) {
return kv.first.c_str();
}
}
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
return nullptr;
}
int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs) {
const int seek = offset_ms/10;
if (seek < 0) {
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
return -1;
}
if (seek >= ctx->mel.n_len) {
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
return -2;
}
// run the encoder
if (whisper_encode(ctx, seek, n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return -6;
}
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return -7;
}
std::vector<std::pair<float, int>> probs_id;
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// softmax
{
float sum = 0;
for (const auto & kv : probs_id) {
sum += exp(kv.first);
}
for (auto & kv : probs_id) {
kv.first = exp(kv.first) / sum;
}
}
{
for (int i = 0; i < (int) probs_id.size(); i++) {
if (lang_probs) {
lang_probs[probs_id[i].second] = probs_id[i].first;
}
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
}
}
return probs_id[0].second;
}
int whisper_n_len(struct whisper_context * ctx) {
return ctx->mel.n_len;
}
@ -2302,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_text_ctx;
}
int whisper_n_audio_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_audio_ctx;
}
int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0;
}
@ -2338,6 +2537,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
return ctx->vocab.token_beg;
}
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
return whisper_token_sot(ctx) + 1 + lang_id;
}
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
}
@ -2371,7 +2574,10 @@ const char * whisper_print_system_info(void) {
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
@ -2569,10 +2775,25 @@ int whisper_full(
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
return -2;
}
}
// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
if (lang_id < 0) {
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3;
}
params.language = whisper_lang_str(lang_id);
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
@ -2605,13 +2826,18 @@ int whisper_full(
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -4;
}
ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
const int lang_id = whisper_lang_id(params.language);
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
} else {
@ -2639,10 +2865,17 @@ int whisper_full(
}
}
// of only 1 second left, then stop
if (seek + 100 >= seek_end) {
break;
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
if (seek > seek_start && seek + 500 >= seek_end) {
prompt_past.clear();
}
if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@ -2653,14 +2886,14 @@ int whisper_full(
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return 7;
return -4;
}
int n_past = 0;
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
if (prompt_past.size() > 0) {
if (!prompt_past.empty()) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
@ -2686,11 +2919,12 @@ int whisper_full(
tokens_cur.clear();
bool failed = false;
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return 8;
return -5;
}
n_past += prompt.size();
@ -2711,13 +2945,13 @@ int whisper_full(
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time
if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
seek_delta > seek_delta_new && result_len < i) {
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
break;
}
seek_delta = seek_delta_new;
result_len = i + 1;
has_ts = true;
}
// add it to the context
@ -2726,11 +2960,14 @@ int whisper_full(
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
// printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
// end of segment
if (token.id == whisper_token_eot(ctx) || // end of text token
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
@ -2765,8 +3002,14 @@ int whisper_full(
}
if (failed) {
fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
// when we fail to sample timestamp token, retry by clearing the past prompt
// if it fails again, then we advance the window by 1 second
if (!prompt_past.empty()) {
prompt_past.clear();
} else {
fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
seek += 100;
}
continue;
}
@ -2778,11 +3021,11 @@ int whisper_full(
}
// store the text from this iteration
if (tokens_cur.size() > 0) {
if (!tokens_cur.empty()) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
std::string text = "";
std::string text;
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
@ -2901,10 +3144,9 @@ int whisper_full_parallel(
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = ctxs[i].buf_memory.size(),
.mem_buffer = ctxs[i].buf_memory.data(),
};
struct ggml_init_params params;
params.mem_size = ctxs[i].buf_memory.size();
params.mem_buffer = ctxs[i].buf_memory.data();
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
@ -2990,7 +3232,7 @@ int whisper_full_parallel(
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (ctx->result_all.size() > 0) {
if (!ctx->result_all.empty()) {
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
}

@ -139,12 +139,45 @@ extern "C" {
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns -1 on failure
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whispe_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token
@ -160,6 +193,7 @@ extern "C" {
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (void);
@ -225,6 +259,7 @@ extern "C" {
const whisper_token * prompt_tokens;
int prompt_n_tokens;
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
struct {

@ -681,7 +681,6 @@ struct ggml_opt_params {
bool print_forward_graph;
bool print_backward_graph;
union {
// ADAM parameters
struct {
int n_iter;
@ -708,7 +707,6 @@ struct ggml_opt_params {
enum ggml_linesearch linesearch;
} lbfgs;
};
};
struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
@ -726,7 +724,10 @@ enum ggml_opt_result ggml_opt(
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);
int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save