|
|
|
@ -1105,7 +1105,7 @@ static bool whisper_encode(
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params;
|
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
|
|
params.mem_buffer = wctx.buf_compute.data();
|
|
|
|
|
params.mem_buffer = wctx.buf_compute.data();
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
|
|
|
|
|
|
|
@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
|
|
|
return res.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int whisper_lang_max_id() {
|
|
|
|
|
auto max_id = 0;
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
max_id = std::max(max_id, kv.second.first);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return max_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int whisper_lang_id(const char * lang) {
|
|
|
|
|
if (!g_lang.count(lang)) {
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
if (kv.second.second == lang) {
|
|
|
|
|
return kv.second.first;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
|
|
|
|
|
return g_lang.at(lang).first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char * whisper_lang_str(int id) {
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
if (kv.second.first == id) {
|
|
|
|
|
return kv.first.c_str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
|
|
|
|
|
return NULL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int whisper_lang_auto_detect(
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
int offset_ms,
|
|
|
|
|
int n_threads,
|
|
|
|
|
float * lang_probs) {
|
|
|
|
|
const int seek = offset_ms/10;
|
|
|
|
|
|
|
|
|
|
if (seek < 0) {
|
|
|
|
|
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (seek >= ctx->mel.n_len) {
|
|
|
|
|
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
|
|
|
|
|
return -2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// run the encoder
|
|
|
|
|
if (whisper_encode(ctx, seek, n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
|
return -6;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
|
|
|
|
|
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
return -7;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<float, int>> probs_id;
|
|
|
|
|
for (const auto kv : g_lang) {
|
|
|
|
|
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
|
|
|
|
probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// sort descending
|
|
|
|
|
{
|
|
|
|
|
using pair_type = decltype(probs_id)::value_type;
|
|
|
|
|
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// softmax
|
|
|
|
|
{
|
|
|
|
|
float sum = 0;
|
|
|
|
|
for (const auto & kv : probs_id) {
|
|
|
|
|
sum += exp(kv.first);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto & kv : probs_id) {
|
|
|
|
|
kv.first = exp(kv.first) / sum;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
for (int i = 0; i < probs_id.size(); i++) {
|
|
|
|
|
if (lang_probs) {
|
|
|
|
|
lang_probs[probs_id[i].second] = probs_id[i].first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return probs_id[0].second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int whisper_n_len(struct whisper_context * ctx) {
|
|
|
|
|
return ctx->mel.n_len;
|
|
|
|
|
}
|
|
|
|
@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
|
|
|
|
|
return ctx->vocab.token_beg;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Maps a language id to its token: the token that sits lang_id + 1 positions
// after the start-of-transcript token.
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
    const whisper_token sot = whisper_token_sot(ctx);
    return sot + 1 + lang_id;
}
|
|
|
|
|
|
|
|
|
|
// Returns the translate token, which is a constant of whisper_vocab and does
// not depend on any context.
whisper_token whisper_token_translate(void) {
    const whisper_token tok = whisper_vocab::token_translate;
    return tok;
}
|
|
|
|
@ -2661,10 +2760,25 @@ int whisper_full(
|
|
|
|
|
} else {
|
|
|
|
|
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
|
|
|
|
return -1;
|
|
|
|
|
return -2;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// auto-detect language if not specified
|
|
|
|
|
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
|
|
|
|
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
|
|
|
|
|
|
|
|
|
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
|
|
|
|
|
if (lang_id < 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
|
|
|
|
return -3;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
params.language = whisper_lang_str(lang_id);
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
|
ctx->t_beg = 0;
|
|
|
|
|
ctx->t_last = 0;
|
|
|
|
@ -2703,7 +2817,8 @@ int whisper_full(
|
|
|
|
|
// these tokens determine the task that will be performed
|
|
|
|
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
|
|
|
if (whisper_is_multilingual(ctx)) {
|
|
|
|
|
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
|
|
|
|
|
const int lang_id = whisper_lang_id(params.language);
|
|
|
|
|
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
|
|
|
|
if (params.translate) {
|
|
|
|
|
prompt_init.push_back(whisper_token_translate());
|
|
|
|
|
} else {
|
|
|
|
@ -2752,7 +2867,7 @@ int whisper_full(
|
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
|
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
|
return 7;
|
|
|
|
|
return -4;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int n_past = 0;
|
|
|
|
@ -2790,7 +2905,7 @@ int whisper_full(
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
return 8;
|
|
|
|
|
return -5;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
n_past += prompt.size();
|
|
|
|
|