whisper : add new-segment callback

Can be used to process new segments as they are being generated.
Sample usage in main, for printing the resulting segments during the
inference.
pull/78/head
Georgi Gerganov 2 years ago
parent 8f95c25aed
commit 7affd309d3

@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n");
}
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
// print the last segment
const int i = n_segments - 1;
if (i == 0) {
printf("\n");
}
if (params.no_timestamps) {
if (params.print_colors) {
// TODO
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
fflush(stdout);
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = !params.print_colors;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens;
@ -303,47 +352,15 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads;
wparams.offset_ms = params.offset_t_ms;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 7;
}
// print result
// this callback is called on each new segment
if (!wparams.print_realtime) {
printf("\n");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
if (params.no_timestamps) {
if (params.print_colors) {
// TODO
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
fflush(stdout);
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params;
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 7;
}
printf("\n");

@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
}
@ -2549,6 +2555,9 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
@ -2576,6 +2585,9 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
}
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
}
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].id;
}
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].p;
}

@ -160,6 +160,11 @@ extern "C" {
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
};
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
@ -184,6 +189,9 @@ extern "C" {
int beam_width;
int n_best;
} beam_search;
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -212,6 +220,7 @@ extern "C" {
// Get the token text of the specified token in the specified segment.
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);

Loading…
Cancel
Save