diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8deb3f0..6a697ac 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -53,22 +53,23 @@ void replace_all(std::string & s, const std::string & search, const std::string // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t n_processors = 1; - int32_t offset_t_ms = 0; - int32_t offset_n = 0; - int32_t duration_ms = 0; + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; int32_t max_context = -1; - int32_t max_len = 0; - int32_t best_of = 5; + int32_t max_len = 0; + int32_t best_of = 5; int32_t beam_size = -1; - float word_thold = 0.01f; - float entropy_thold = 2.4f; - float logprob_thold = -1.0f; + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; bool speed_up = false; bool translate = false; bool diarize = false; + bool no_fallback = false; bool output_txt = false; bool output_vtt = false; bool output_srt = false; @@ -117,6 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } @@ -162,6 +164,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); @@ -647,17 +650,19 @@ int main(int argc, char ** argv) { wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.thold_pt = params.word_thold; - wparams.entropy_thold = params.entropy_thold; - wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.speed_up = params.speed_up; + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; whisper_print_user_data user_data = { ¶ms, &pcmf32s };