feat: added flag to split on word

pull/455/head
Matija Pevec 2 years ago
parent dd6f07ebfe
commit 5c1af3df22

@ -69,6 +69,7 @@ struct whisper_params {
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool diarize = false; bool diarize = false;
bool split_on_word = false;
bool output_txt = false; bool output_txt = false;
bool output_vtt = false; bool output_vtt = false;
bool output_srt = false; bool output_srt = false;
@ -117,6 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split_on_word") { params.split_on_word = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
@ -154,6 +156,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
@ -653,6 +656,7 @@ int main(int argc, char ** argv) {
wparams.entropy_thold = params.entropy_thold; wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold; wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;

@ -2780,6 +2780,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f, /*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f, /*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0, /*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0, /*.max_tokens =*/ 0,
/*.speed_up =*/ false, /*.speed_up =*/ false,
@ -2866,9 +2867,16 @@ static inline void trim(std::string &s) {
ltrim(s); ltrim(s);
} }
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
if (!split_on_word) return true;
std::string s = txt;
return s.substr(0, 1) == " ";
}
// wrap the last segment to max_len characters // wrap the last segment to max_len characters
// returns the number of new segments // returns the number of new segments
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
auto segment = ctx.result_all.back(); auto segment = ctx.result_all.back();
int res = 1; int res = 1;
@ -2883,11 +2891,9 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
} }
const auto txt = whisper_token_to_str(&ctx, token.id); const auto txt = whisper_token_to_str(&ctx, token.id);
std::string s = txt;
const int cur = strlen(txt); const int cur = strlen(txt);
if (acc + cur > max_len && i > 0 && s.substr(0, 1) == " ") { if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
// split here // split here
trim(text); trim(text);
ctx.result_all.back().text = std::move(text); ctx.result_all.back().text = std::move(text);
@ -3946,7 +3952,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) { if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len); n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
} }
} }
if (params.new_segment_callback) { if (params.new_segment_callback) {
@ -3990,7 +3996,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) { if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len); n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
} }
} }
if (params.new_segment_callback) { if (params.new_segment_callback) {

@ -257,6 +257,7 @@ extern "C" {
float thold_pt; // timestamp token probability threshold (~0.01) float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_tokens; // max tokens per segment (0 = no limit) int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques

Loading…
Cancel
Save