Improve decoding (#291)

* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ration_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders
pull/409/head
Georgi Gerganov 2 years ago committed by GitHub
parent a6dbd9188b
commit 8de452c18b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

2
.gitignore vendored

@ -8,6 +8,7 @@ build/
build-em/ build-em/
build-debug/ build-debug/
build-release/ build-release/
build-static/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/
@ -18,6 +19,7 @@ build-sanitize-thread/
/bench /bench
sync.sh sync.sh
libwhisper.a
libwhisper.so libwhisper.so
compile_commands.json compile_commands.json

@ -212,17 +212,7 @@ make large
## Limitations ## Limitations
- Inference only - Inference only
- No GPU support - No GPU support (yet)
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Another example ## Another example

@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
break; break;
} }
const auto * probs = whisper_get_probs(ctx); // estimate command probability
std::vector<std::pair<float, int>> probs_id; // NOTE: not optimal
{
double psum = 0.0; const auto * logits = whisper_get_logits(ctx);
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending // compute probs from logits via softmax
{ {
using pair_type = decltype(probs_id)::value_type; float max = -1e9;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { for (int i = 0; i < (int) probs.size(); ++i) {
return a.first > b.first; max = std::max(max, logits[i]);
}); }
}
// print the commands and the respective probabilities float sum = 0.0f;
{ for (int i = 0; i < (int) probs.size(); ++i) {
fprintf(stdout, "\n"); probs[i] = expf(logits[i] - max);
for (const auto & cmd : probs_id) { sum += probs[i];
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); }
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
} }
}
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n"); fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
}
fprintf(stdout, "\n");
}
} }
}
// best command // best command
{ {
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
const float prob = probs_id[0].first; const float prob = probs_id[0].first;
const int index = probs_id[0].second; const int index = probs_id[0].second;
fprintf(stdout, "\n"); fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count()); (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n"); fprintf(stdout, "\n");
}
} }
audio.clear(); audio.clear();

@ -59,8 +59,12 @@ struct whisper_params {
int32_t duration_ms = 0; int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f; float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const char * text = whisper_full_get_token_text(ctx, i, j); const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ') if (text[0] == ' ') {
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
}
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
<< 10 * t1 << ", \"" fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
<< text << "\"\n";
} }
return true; return true;
} }
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_progress = params.print_progress; wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
@ -633,12 +646,18 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold; wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); wparams.greedy.best_of = params.best_of;
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = -1;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s }; whisper_print_user_data user_data = { &params, &pcmf32s };

@ -49,6 +49,9 @@ void stream_main(size_t index) {
wparams.max_tokens = 32; wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance wparams.audio_ctx = 768; // partial encoder context for better performance
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = "en"; wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads); printf("stream: using %d threads\n", wparams.n_threads);

@ -615,6 +615,9 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

@ -35,10 +35,10 @@
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; }; 0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; }; 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; }; 0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = "<group>"; }; 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = "<group>"; }; 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; }; 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; }; 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; }; 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; }; 0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */ /* End PBXFileReference section */
@ -129,7 +129,8 @@
0AAC5DC729539EB0003032C3 /* whisper.cpp */, 0AAC5DC729539EB0003032C3 /* whisper.cpp */,
0AAC5DC829539EB0003032C3 /* whisper.h */, 0AAC5DC829539EB0003032C3 /* whisper.h */,
); );
path = whisper.cpp; name = whisper.cpp;
path = ../..;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = { 0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {

File diff suppressed because it is too large Load Diff

@ -74,6 +74,7 @@ extern "C" {
whisper_token tid; // forced timestamp token id whisper_token tid; // forced timestamp token id
float p; // probability of the token float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens float ptsum; // sum of probabilities of all timestamp tokens
@ -136,6 +137,7 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder. // tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls. // n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success // Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode( WHISPER_API int whisper_decode(
struct whisper_context * ctx, struct whisper_context * ctx,
const whisper_token * tokens, const whisper_token * tokens,
@ -143,14 +145,6 @@ extern "C" {
int n_past, int n_past,
int n_threads); int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens. // Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens. // The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens // Returns the number of tokens on success, no more than n_max_tokens
@ -192,8 +186,11 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token // Token logits obtained from the last call to whisper_decode()
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); // The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -222,8 +219,8 @@ extern "C" {
// Available sampling strategies // Available sampling strategies
enum whisper_sampling_strategy { enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
}; };
// Text segment callback // Text segment callback
@ -243,17 +240,17 @@ extern "C" {
enum whisper_sampling_strategy strategy; enum whisper_sampling_strategy strategy;
int n_threads; int n_threads;
int n_max_text_ctx; int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms int duration_ms; // audio duration to process in ms
bool translate; bool translate;
bool no_context; bool no_context; // do not use initial prompt for the decoder (if any)
bool single_segment; // force single segment output (useful for streaming) bool single_segment; // force single segment output (useful for streaming)
bool print_special; bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; bool print_progress; // print progress information
bool print_realtime; bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps // [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps bool token_timestamps; // enable token-level timestamps
@ -263,10 +260,11 @@ extern "C" {
int max_tokens; // max tokens per segment (0 = no limit) int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default) int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt // tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call // these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens; const whisper_token * prompt_tokens;
int prompt_n_tokens; int prompt_n_tokens;
@ -274,19 +272,35 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto" // for auto-detection, set to nullptr, "" or "auto"
const char * language; const char * language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct { struct {
int n_past; int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy; } greedy;
struct { struct {
int n_past; int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
int beam_width;
int n_best; float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search; } beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback; whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data; void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback; whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data; void * encoder_begin_callback_user_data;
}; };

Loading…
Cancel
Save