main : add "--prompt" command line argument (#90)

This allows the user to provide an initial prompt that is used at the start
of the processing.
pull/285/head
Georgi Gerganov 1 year ago
parent 4312995974
commit b8065d90f5
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -73,8 +73,9 @@ struct whisper_params {
bool print_colors = false;
bool no_timestamps = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string language = "en";
std::string prompt = "";
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
};
@ -113,6 +114,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
else {
@ -150,6 +152,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
@ -462,6 +465,22 @@ int main(int argc, char ** argv) {
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (params.prompt.size() > 0) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
@ -577,7 +596,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@ -599,6 +617,9 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment

Loading…
Cancel
Save