|
|
|
@ -133,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
|
|
// number of bytes in one mebibyte; unit for the MEM_REQ_* tables
static const size_t MB = static_cast<size_t>(1) << 20;
|
|
|
|
|
|
|
|
|
|
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
|
|
|
|
|
{ MODEL_TINY, 86ull*MB },
|
|
|
|
|
{ MODEL_BASE, 165ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 540ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 1650ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 3260ull*MB },
|
|
|
|
|
{ MODEL_TINY, 74ull*MB },
|
|
|
|
|
{ MODEL_BASE, 142ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 466ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 1464ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 2952ull*MB },
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
|
|
|
|
|
{ MODEL_TINY, 12ull*MB },
|
|
|
|
|
{ MODEL_BASE, 24ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 70ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 184ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 306ull*MB },
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
|
|
@ -410,6 +418,12 @@ struct whisper_context {
|
|
|
|
|
std::vector<whisper_segment> result_all;
|
|
|
|
|
|
|
|
|
|
std::vector<whisper_token> prompt_past;
|
|
|
|
|
|
|
|
|
|
// [EXPERIMENTAL] token-level timestamps data
|
|
|
|
|
int64_t t_beg;
|
|
|
|
|
int64_t t_last;
|
|
|
|
|
whisper_token tid_last;
|
|
|
|
|
std::vector<float> energy; // PCM signal energy
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// load the model from a ggml file
|
|
|
|
@ -423,7 +437,7 @@ struct whisper_context {
|
|
|
|
|
//
|
|
|
|
|
// see the convert-pt-to-ggml.py script for details
|
|
|
|
|
//
|
|
|
|
|
bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
|
|
|
|
|
|
|
|
|
|
auto & model = wctx.model;
|
|
|
|
@ -498,7 +512,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
|
|
|
|
|
wctx.buf_model = new std::vector<uint8_t>();
|
|
|
|
|
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
|
|
|
|
|
wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
|
|
|
|
|
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
|
|
|
|
|
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
|
|
|
|
|
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
|
|
|
|
|
|
|
|
|
@ -722,20 +736,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create the ggml memory context
|
|
|
|
|
{
|
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
|
.mem_size = wctx.buf_memory.size(),
|
|
|
|
|
.mem_buffer = wctx.buf_memory.data(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
|
if (!model.ctx_mem) {
|
|
|
|
|
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// prepare memory for the weights
|
|
|
|
|
{
|
|
|
|
|
auto & ctx = model.ctx;
|
|
|
|
@ -932,6 +932,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create the ggml memory context
|
|
|
|
|
{
|
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
|
.mem_size = wctx.buf_memory.size(),
|
|
|
|
|
.mem_buffer = wctx.buf_memory.data(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
|
if (!model.ctx_mem) {
|
|
|
|
|
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// key + value memory
|
|
|
|
|
{
|
|
|
|
|
auto & ctx = model.ctx_mem;
|
|
|
|
@ -1054,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|
|
|
|
// - n_threads: number of threads to use
|
|
|
|
|
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
|
|
|
//
|
|
|
|
|
bool whisper_encode(
|
|
|
|
|
static bool whisper_encode(
|
|
|
|
|
whisper_context & wctx,
|
|
|
|
|
const int n_threads,
|
|
|
|
|
const int mel_offset) {
|
|
|
|
@ -1440,7 +1454,7 @@ bool whisper_encode(
|
|
|
|
|
// - n_tokens: number of tokens in the prompt
|
|
|
|
|
// - n_past: number of past tokens to prefix the prompt with
|
|
|
|
|
//
|
|
|
|
|
bool whisper_decode(
|
|
|
|
|
static bool whisper_decode(
|
|
|
|
|
whisper_context & wctx,
|
|
|
|
|
const int n_threads,
|
|
|
|
|
const whisper_token * tokens,
|
|
|
|
@ -1803,10 +1817,12 @@ bool whisper_decode(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
|
whisper_token_data whisper_sample_best(
|
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
|
const whisper_vocab & vocab,
|
|
|
|
|
const float * probs) {
|
|
|
|
|
whisper_token_data result;
|
|
|
|
|
whisper_token_data result = {
|
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
int n_logits = vocab.id_to_token.size();
|
|
|
|
|
|
|
|
|
@ -1879,7 +1895,7 @@ whisper_token_data whisper_sample_best(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// samples only from the timestamps tokens
|
|
|
|
|
whisper_vocab::id whisper_sample_timestamp(
|
|
|
|
|
static whisper_vocab::id whisper_sample_timestamp(
|
|
|
|
|
const whisper_vocab & vocab,
|
|
|
|
|
const float * probs) {
|
|
|
|
|
int n_logits = vocab.id_to_token.size();
|
|
|
|
@ -1931,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
|
|
// naive Discrete Fourier Transform
|
|
|
|
|
// input is real-valued
|
|
|
|
|
// output is complex-valued
|
|
|
|
|
void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
int N = in.size();
|
|
|
|
|
|
|
|
|
|
out.resize(N*2);
|
|
|
|
@ -1955,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
// poor man's implementation - use something better
|
|
|
|
|
// input is real-valued
|
|
|
|
|
// output is complex-valued
|
|
|
|
|
void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
out.resize(in.size()*2);
|
|
|
|
|
|
|
|
|
|
int N = in.size();
|
|
|
|
@ -2006,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
|
|
|
|
|
bool log_mel_spectrogram(
|
|
|
|
|
static bool log_mel_spectrogram(
|
|
|
|
|
const float * samples,
|
|
|
|
|
const int n_samples,
|
|
|
|
|
const int sample_rate,
|
|
|
|
@ -2323,6 +2339,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
|
|
|
/*.n_max_text_ctx =*/ 16384,
|
|
|
|
|
/*.offset_ms =*/ 0,
|
|
|
|
|
/*.duration_ms =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.translate =*/ false,
|
|
|
|
|
/*.no_context =*/ false,
|
|
|
|
@ -2331,6 +2348,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
/*.print_realtime =*/ false,
|
|
|
|
|
/*.print_timestamps =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.token_timestamps =*/ false,
|
|
|
|
|
/*.thold_pt =*/ 0.01f,
|
|
|
|
|
/*.thold_ptsum =*/ 0.01f,
|
|
|
|
|
/*.max_len =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.language =*/ "en",
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
@ -2355,6 +2377,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
|
|
|
/*.n_max_text_ctx =*/ 16384,
|
|
|
|
|
/*.offset_ms =*/ 0,
|
|
|
|
|
/*.duration_ms =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.translate =*/ false,
|
|
|
|
|
/*.no_context =*/ false,
|
|
|
|
@ -2363,6 +2386,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
/*.print_realtime =*/ false,
|
|
|
|
|
/*.print_timestamps =*/ true,
|
|
|
|
|
|
|
|
|
|
/*.token_timestamps =*/ false,
|
|
|
|
|
/*.thold_pt =*/ 0.01f,
|
|
|
|
|
/*.thold_ptsum =*/ 0.01f,
|
|
|
|
|
/*.max_len =*/ 0,
|
|
|
|
|
|
|
|
|
|
/*.language =*/ "en",
|
|
|
|
|
|
|
|
|
|
/*.greedy =*/ {
|
|
|
|
@ -2384,6 +2412,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// forward declarations of helpers that are defined further down in this file

// per-sample signal energy of a PCM buffer
static std::vector<float> get_signal_energy(
        const float * signal,
        int           n_samples,
        int           n_samples_per_half_window);

// [EXPERIMENTAL] fills in the t0/t1 timestamps of the tokens of one segment
static void whisper_exp_compute_token_level_timestamps(
        struct whisper_context * ctx, int i_segment, float thold_pt, float thold_ptsum);
|
|
|
|
|
|
|
|
|
|
// wrap the last segment to max_len characters
|
|
|
|
|
// returns the number of new segments
|
|
|
|
|
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
|
|
|
|
|
auto segment = ctx->result_all.back();
|
|
|
|
|
|
|
|
|
|
int res = 1;
|
|
|
|
|
int acc = 0;
|
|
|
|
|
|
|
|
|
|
std::string text;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) segment.tokens.size(); i++) {
|
|
|
|
|
const auto & token = segment.tokens[i];
|
|
|
|
|
if (token.id >= whisper_token_eot(ctx)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto txt = whisper_token_to_str(ctx, token.id);
|
|
|
|
|
|
|
|
|
|
const int cur = strlen(txt);
|
|
|
|
|
|
|
|
|
|
if (acc + cur > max_len && i > 0) {
|
|
|
|
|
// split here
|
|
|
|
|
ctx->result_all.back().text = std::move(text);
|
|
|
|
|
ctx->result_all.back().t1 = token.t0;
|
|
|
|
|
ctx->result_all.back().tokens.resize(i);
|
|
|
|
|
|
|
|
|
|
ctx->result_all.push_back({});
|
|
|
|
|
ctx->result_all.back().t0 = token.t0;
|
|
|
|
|
ctx->result_all.back().t1 = segment.t1;
|
|
|
|
|
|
|
|
|
|
// add tokens [i, end] to the new segment
|
|
|
|
|
ctx->result_all.back().tokens.insert(
|
|
|
|
|
ctx->result_all.back().tokens.end(),
|
|
|
|
|
segment.tokens.begin() + i,
|
|
|
|
|
segment.tokens.end());
|
|
|
|
|
|
|
|
|
|
acc = 0;
|
|
|
|
|
text = "";
|
|
|
|
|
|
|
|
|
|
segment = ctx->result_all.back();
|
|
|
|
|
i = -1;
|
|
|
|
|
|
|
|
|
|
res++;
|
|
|
|
|
} else {
|
|
|
|
|
acc += cur;
|
|
|
|
|
text += txt;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->result_all.back().text = std::move(text);
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int whisper_full(
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
struct whisper_full_params params,
|
|
|
|
@ -2400,12 +2490,20 @@ int whisper_full(
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
|
ctx->t_beg = 0;
|
|
|
|
|
ctx->t_last = 0;
|
|
|
|
|
ctx->tid_last = 0;
|
|
|
|
|
ctx->energy = get_signal_energy(samples, n_samples, 32);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int seek_start = params.offset_ms/10;
|
|
|
|
|
const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
|
|
|
|
|
|
|
|
|
|
// if length of spectrogram is less than 1s (100 samples), then return
|
|
|
|
|
// basically don't process anything that is less than 1s
|
|
|
|
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
|
|
|
|
if (whisper_n_len(ctx) < 100 + seek_start) {
|
|
|
|
|
if (seek_end < 100 + seek_start) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2438,7 +2536,7 @@ int whisper_full(
|
|
|
|
|
// main loop
|
|
|
|
|
int seek = seek_start;
|
|
|
|
|
while (true) {
|
|
|
|
|
int progress_cur = (100*seek)/whisper_n_len(ctx);
|
|
|
|
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
|
|
|
while (progress_cur >= progress_prev + progress_step) {
|
|
|
|
|
progress_prev += progress_step;
|
|
|
|
|
if (params.print_progress) {
|
|
|
|
@ -2446,7 +2544,7 @@ int whisper_full(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (seek + 100 >= whisper_n_len(ctx)) {
|
|
|
|
|
if (seek + 100 >= seek_end) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2527,7 +2625,7 @@ int whisper_full(
|
|
|
|
|
// end of text token
|
|
|
|
|
if (token.id == whisper_token_eot(ctx)) {
|
|
|
|
|
if (result_len == 0) {
|
|
|
|
|
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
|
|
|
|
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
|
result_len = i + 1;
|
|
|
|
|
} else {
|
|
|
|
|
// TODO: figure out how to resolve this
|
|
|
|
@ -2549,6 +2647,7 @@ int whisper_full(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// shrink down to result_len
|
|
|
|
|
tokens_cur.resize(result_len);
|
|
|
|
|
|
|
|
|
|
for (const auto & r : tokens_cur) {
|
|
|
|
@ -2587,8 +2686,19 @@ int whisper_full(
|
|
|
|
|
for (int j = i0; j <= i; j++) {
|
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int n_new = 1;
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
|
n_new = whisper_wrap_segment(ctx, params.max_len);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
text = "";
|
|
|
|
@ -2617,8 +2727,19 @@ int whisper_full(
|
|
|
|
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
|
|
|
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int n_new = 1;
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
|
whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
|
|
|
|
|
|
|
|
if (params.max_len > 0) {
|
|
|
|
|
n_new = whisper_wrap_segment(ctx, params.max_len);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
|
|
|
|
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -2752,7 +2873,7 @@ int whisper_full_parallel(
|
|
|
|
|
|
|
|
|
|
// call the new_segment_callback for each segment
|
|
|
|
|
if (params.new_segment_callback) {
|
|
|
|
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
|
|
|
|
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2828,3 +2949,304 @@ const char * whisper_print_system_info() {
|
|
|
|
|
|
|
|
|
|
return s.c_str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// Experimental stuff below
|
|
|
|
|
//
|
|
|
|
|
// Not sure if these should be part of the library at all, because the quality of the results is not
|
|
|
|
|
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// token-level timestamps
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
static int timestamp_to_sample(int64_t t, int n_samples) {
|
|
|
|
|
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int64_t sample_to_timestamp(int i_sample) {
|
|
|
|
|
return (100*i_sample)/WHISPER_SAMPLE_RATE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// heuristic "pronunciation cost" of a piece of text - higher means it takes longer to say
//
// per-character weights:
//   ' '                 -> 0.01
//   ','                 -> 2
//   '.', '!', '?'       -> 3
//   '0' .. '9'          -> 3
//   anything else       -> 1
//
// obviously, can be improved
static float voice_length(const std::string & text) {
    float total = 0.0f;

    for (const char ch : text) {
        float weight;

        switch (ch) {
            case ' ':
                weight = 0.01f;
                break;
            case ',':
                weight = 2.00f;
                break;
            case '.':
            case '!':
            case '?':
                weight = 3.00f;
                break;
            default:
                weight = (ch >= '0' && ch <= '9') ? 3.00f : 1.00f;
                break;
        }

        total += weight;
    }

    return total;
}
|
|
|
|
|
|
|
|
|
|
// moving average of the absolute signal value
//
// for every sample i, sums |signal[k]| over the part of the window
// [i - hw, i + hw] that lies inside the signal and divides by the full window
// size (2*hw + 1), where hw = n_samples_per_half_window
//
// returns one value per input sample
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
    const int half   = n_samples_per_half_window;
    const int window = 2*half + 1;

    std::vector<float> energy(n_samples);

    for (int i = 0; i < n_samples; i++) {
        // clip the window to the valid sample range
        const int k_first = std::max(i - half, 0);
        const int k_last  = std::min(i + half, n_samples - 1);

        float acc = 0;
        for (int k = k_first; k <= k_last; k++) {
            acc += fabs(signal[k]);
        }

        energy[i] = acc/window;
    }

    return energy;
}
|
|
|
|
|
|
|
|
|
|
static void whisper_exp_compute_token_level_timestamps(
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
int i_segment,
|
|
|
|
|
float thold_pt,
|
|
|
|
|
float thold_ptsum) {
|
|
|
|
|
auto & segment = ctx->result_all[i_segment];
|
|
|
|
|
auto & tokens = segment.tokens;
|
|
|
|
|
|
|
|
|
|
const int n_samples = ctx->energy.size();
|
|
|
|
|
|
|
|
|
|
if (n_samples == 0) {
|
|
|
|
|
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int64_t t0 = segment.t0;
|
|
|
|
|
const int64_t t1 = segment.t1;
|
|
|
|
|
|
|
|
|
|
const int s0 = timestamp_to_sample(t0, n_samples);
|
|
|
|
|
const int s1 = timestamp_to_sample(t1, n_samples);
|
|
|
|
|
|
|
|
|
|
const int n = tokens.size();
|
|
|
|
|
|
|
|
|
|
if (n == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (n == 1) {
|
|
|
|
|
tokens[0].t0 = t0;
|
|
|
|
|
tokens[0].t1 = t1;
|
|
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto & t_beg = ctx->t_beg;
|
|
|
|
|
auto & t_last = ctx->t_last;
|
|
|
|
|
auto & tid_last = ctx->tid_last;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n; ++j) {
|
|
|
|
|
auto & token = tokens[j];
|
|
|
|
|
|
|
|
|
|
if (j == 0) {
|
|
|
|
|
if (token.id == whisper_token_beg(ctx)) {
|
|
|
|
|
tokens[j ].t0 = t0;
|
|
|
|
|
tokens[j ].t1 = t0;
|
|
|
|
|
tokens[j + 1].t0 = t0;
|
|
|
|
|
|
|
|
|
|
t_beg = t0;
|
|
|
|
|
t_last = t0;
|
|
|
|
|
tid_last = whisper_token_beg(ctx);
|
|
|
|
|
} else {
|
|
|
|
|
tokens[j ].t0 = t_last;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
tokens[j].id = token.id;
|
|
|
|
|
tokens[j].tid = token.tid;
|
|
|
|
|
tokens[j].p = token.p;
|
|
|
|
|
tokens[j].pt = token.pt;
|
|
|
|
|
tokens[j].ptsum = token.ptsum;
|
|
|
|
|
|
|
|
|
|
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
|
|
|
|
|
|
|
|
|
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
|
|
|
|
if (j > 0) {
|
|
|
|
|
tokens[j - 1].t1 = tt;
|
|
|
|
|
}
|
|
|
|
|
tokens[j].t0 = tt;
|
|
|
|
|
tid_last = token.tid;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokens[n - 2].t1 = t1;
|
|
|
|
|
tokens[n - 1].t0 = t1;
|
|
|
|
|
tokens[n - 1].t1 = t1;
|
|
|
|
|
|
|
|
|
|
t_last = t1;
|
|
|
|
|
|
|
|
|
|
// find intervals of tokens with unknown timestamps
|
|
|
|
|
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
|
|
|
|
|
{
|
|
|
|
|
int p0 = 0;
|
|
|
|
|
int p1 = 0;
|
|
|
|
|
|
|
|
|
|
while (true) {
|
|
|
|
|
while (p1 < n && tokens[p1].t1 < 0) {
|
|
|
|
|
p1++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (p1 >= n) {
|
|
|
|
|
p1--;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (p1 > p0) {
|
|
|
|
|
double psum = 0.0;
|
|
|
|
|
for (int j = p0; j <= p1; j++) {
|
|
|
|
|
psum += tokens[j].vlen;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
|
|
|
|
|
|
|
|
|
|
const double dt = tokens[p1].t1 - tokens[p0].t0;
|
|
|
|
|
|
|
|
|
|
// split the time proportionally to the voice length
|
|
|
|
|
for (int j = p0 + 1; j <= p1; j++) {
|
|
|
|
|
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
|
|
|
|
|
|
|
|
|
|
tokens[j - 1].t1 = ct;
|
|
|
|
|
tokens[j ].t0 = ct;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
p1++;
|
|
|
|
|
p0 = p1;
|
|
|
|
|
if (p1 >= n) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fix up (just in case)
|
|
|
|
|
for (int j = 0; j < n - 1; j++) {
|
|
|
|
|
if (tokens[j].t1 < 0) {
|
|
|
|
|
tokens[j + 1].t0 = tokens[j].t1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (j > 0) {
|
|
|
|
|
if (tokens[j - 1].t1 > tokens[j].t0) {
|
|
|
|
|
tokens[j].t0 = tokens[j - 1].t1;
|
|
|
|
|
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// VAD
|
|
|
|
|
// expand or contract tokens based on voice activity
|
|
|
|
|
{
|
|
|
|
|
const int hw = WHISPER_SAMPLE_RATE/8;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < n; j++) {
|
|
|
|
|
if (tokens[j].id >= whisper_token_eot(ctx)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
|
|
|
|
|
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
|
|
|
|
|
|
|
|
|
|
const int ss0 = std::max(s0 - hw, 0);
|
|
|
|
|
const int ss1 = std::min(s1 + hw, n_samples);
|
|
|
|
|
|
|
|
|
|
const int ns = ss1 - ss0;
|
|
|
|
|
|
|
|
|
|
float sum = 0.0f;
|
|
|
|
|
|
|
|
|
|
for (int k = ss0; k < ss1; k++) {
|
|
|
|
|
sum += ctx->energy[k];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float thold = 0.5*sum/ns;
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
int k = s0;
|
|
|
|
|
if (ctx->energy[k] > thold && j > 0) {
|
|
|
|
|
while (k > 0 && ctx->energy[k] > thold) {
|
|
|
|
|
k--;
|
|
|
|
|
}
|
|
|
|
|
tokens[j].t0 = sample_to_timestamp(k);
|
|
|
|
|
if (tokens[j].t0 < tokens[j - 1].t1) {
|
|
|
|
|
tokens[j].t0 = tokens[j - 1].t1;
|
|
|
|
|
} else {
|
|
|
|
|
s0 = k;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
while (ctx->energy[k] < thold && k < s1) {
|
|
|
|
|
k++;
|
|
|
|
|
}
|
|
|
|
|
s0 = k;
|
|
|
|
|
tokens[j].t0 = sample_to_timestamp(k);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
int k = s1;
|
|
|
|
|
if (ctx->energy[k] > thold) {
|
|
|
|
|
while (k < n_samples - 1 && ctx->energy[k] > thold) {
|
|
|
|
|
k++;
|
|
|
|
|
}
|
|
|
|
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
|
|
|
if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
|
|
|
|
tokens[j].t1 = tokens[j + 1].t0;
|
|
|
|
|
} else {
|
|
|
|
|
s1 = k;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
while (ctx->energy[k] < thold && k > s0) {
|
|
|
|
|
k--;
|
|
|
|
|
}
|
|
|
|
|
s1 = k;
|
|
|
|
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fixed token expand (optional)
|
|
|
|
|
//{
|
|
|
|
|
// const int t_expand = 0;
|
|
|
|
|
|
|
|
|
|
// for (int j = 0; j < n; j++) {
|
|
|
|
|
// if (j > 0) {
|
|
|
|
|
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
|
|
|
|
|
// }
|
|
|
|
|
// if (j < n - 1) {
|
|
|
|
|
// tokens[j].t1 = tokens[j].t1 + t_expand;
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
// debug info
|
|
|
|
|
//for (int j = 0; j < n; ++j) {
|
|
|
|
|
// const auto & token = tokens[j];
|
|
|
|
|
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
|
|
|
|
|
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
|
|
|
|
|
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
|
|
|
|
|
|
|
|
|
|
// if (tokens[j].id >= whisper_token_eot(ctx)) {
|
|
|
|
|
// continue;
|
|
|
|
|
// }
|
|
|
|
|
//}
|
|
|
|
|
}
|
|
|
|
|