|
|
@ -14,6 +14,7 @@
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <thread>
|
|
|
|
#include <thread>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
#include <regex>
|
|
|
|
|
|
|
|
|
|
|
|
#define USE_FLASH_ATTN
|
|
|
|
#define USE_FLASH_ATTN
|
|
|
|
//#define USE_FLASH_FF
|
|
|
|
//#define USE_FLASH_FF
|
|
|
@ -203,6 +204,10 @@ struct whisper_vocab {
|
|
|
|
std::map<token, id> token_to_id;
|
|
|
|
std::map<token, id> token_to_id;
|
|
|
|
std::map<id, token> id_to_token;
|
|
|
|
std::map<id, token> id_to_token;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// used to avoid memory allocations during sampling
|
|
|
|
|
|
|
|
// TODO: move to whisper_context in the future
|
|
|
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
|
|
|
|
|
|
|
|
id token_eot = 50256;
|
|
|
|
id token_eot = 50256;
|
|
|
|
id token_sot = 50257;
|
|
|
|
id token_sot = 50257;
|
|
|
|
id token_prev = 50360;
|
|
|
|
id token_prev = 50360;
|
|
|
@ -429,6 +434,12 @@ struct whisper_context {
|
|
|
|
int32_t exp_n_audio_ctx; // 0 - use default
|
|
|
|
int32_t exp_n_audio_ctx; // 0 - use default
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Reads exactly sizeof(T) raw bytes from `fin` into the storage of `dest`.
//
// fin  - input stream to read from
// dest - object whose bytes are overwritten with the data read
//
// The stream state is not inspected here; a short read leaves `fin` in a
// failed/eof state for the caller to detect.
template<typename T>
static void read_safe(std::ifstream & fin, T & dest)
{
    char * const raw = reinterpret_cast<char *>(&dest);
    fin.read(raw, sizeof(dest));
}
|
|
|
|
|
|
|
|
|
|
|
|
// load the model from a ggml file
|
|
|
|
// load the model from a ggml file
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// file format:
|
|
|
|
// file format:
|
|
|
@ -455,7 +466,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
// verify magic
|
|
|
|
// verify magic
|
|
|
|
{
|
|
|
|
{
|
|
|
|
uint32_t magic;
|
|
|
|
uint32_t magic;
|
|
|
|
fin.read((char *) &magic, sizeof(magic));
|
|
|
|
read_safe(fin, magic);
|
|
|
|
if (magic != 0x67676d6c) {
|
|
|
|
if (magic != 0x67676d6c) {
|
|
|
|
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
|
|
|
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
@ -466,17 +477,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
{
|
|
|
|
{
|
|
|
|
auto & hparams = model.hparams;
|
|
|
|
auto & hparams = model.hparams;
|
|
|
|
|
|
|
|
|
|
|
|
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
|
|
|
read_safe(fin, hparams.n_vocab);
|
|
|
|
fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
|
|
|
read_safe(fin, hparams.n_audio_ctx);
|
|
|
|
fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
|
|
|
read_safe(fin, hparams.n_audio_state);
|
|
|
|
fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
|
|
|
read_safe(fin, hparams.n_audio_head);
|
|
|
|
fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
|
|
|
read_safe(fin, hparams.n_audio_layer);
|
|
|
|
fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
|
|
|
|
read_safe(fin, hparams.n_text_ctx);
|
|
|
|
fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
|
|
|
|
read_safe(fin, hparams.n_text_state);
|
|
|
|
fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
|
|
|
|
read_safe(fin, hparams.n_text_head);
|
|
|
|
fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
|
|
|
|
read_safe(fin, hparams.n_text_layer);
|
|
|
|
fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
|
|
|
read_safe(fin, hparams.n_mels);
|
|
|
|
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
|
|
|
read_safe(fin, hparams.f16);
|
|
|
|
|
|
|
|
|
|
|
|
assert(hparams.n_text_state == hparams.n_audio_state);
|
|
|
|
assert(hparams.n_text_state == hparams.n_audio_state);
|
|
|
|
|
|
|
|
|
|
|
@ -524,8 +535,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
{
|
|
|
|
{
|
|
|
|
auto & filters = wctx.model.filters;
|
|
|
|
auto & filters = wctx.model.filters;
|
|
|
|
|
|
|
|
|
|
|
|
fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
|
|
|
|
read_safe(fin, filters.n_mel);
|
|
|
|
fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
|
|
|
|
read_safe(fin, filters.n_fft);
|
|
|
|
|
|
|
|
|
|
|
|
filters.data.resize(filters.n_mel * filters.n_fft);
|
|
|
|
filters.data.resize(filters.n_mel * filters.n_fft);
|
|
|
|
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
|
|
|
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
|
|
@ -534,7 +545,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
// load vocab
|
|
|
|
// load vocab
|
|
|
|
{
|
|
|
|
{
|
|
|
|
int32_t n_vocab = 0;
|
|
|
|
int32_t n_vocab = 0;
|
|
|
|
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
|
|
|
read_safe(fin, n_vocab);
|
|
|
|
|
|
|
|
|
|
|
|
//if (n_vocab != model.hparams.n_vocab) {
|
|
|
|
//if (n_vocab != model.hparams.n_vocab) {
|
|
|
|
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
|
|
|
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
|
|
@ -543,12 +554,23 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
std::string word;
|
|
|
|
std::string word;
|
|
|
|
|
|
|
|
std::vector<char> tmp;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tmp.reserve(128);
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n_vocab; i++) {
|
|
|
|
for (int i = 0; i < n_vocab; i++) {
|
|
|
|
uint32_t len;
|
|
|
|
uint32_t len;
|
|
|
|
fin.read((char *) &len, sizeof(len));
|
|
|
|
read_safe(fin, len);
|
|
|
|
|
|
|
|
|
|
|
|
word.resize(len);
|
|
|
|
if (len > 0) {
|
|
|
|
fin.read((char *) word.data(), len);
|
|
|
|
tmp.resize(len);
|
|
|
|
|
|
|
|
fin.read(&tmp[0], tmp.size()); // read to buffer
|
|
|
|
|
|
|
|
word.assign(&tmp[0], tmp.size());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// seems like we have an empty-string token in multi-language models (i = 50256)
|
|
|
|
|
|
|
|
//fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
|
|
|
|
|
|
|
word = "";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
vocab.token_to_id[word] = i;
|
|
|
|
vocab.token_to_id[word] = i;
|
|
|
|
vocab.id_to_token[i] = word;
|
|
|
|
vocab.id_to_token[i] = word;
|
|
|
@ -588,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
vocab.id_to_token[i] = word;
|
|
|
|
vocab.id_to_token[i] = word;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab.probs_id.reserve(n_vocab);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
@ -606,7 +633,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
|
|
|
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
|
|
|
|
|
|
|
|
|
|
|
size_t ctx_size = 0;
|
|
|
|
size_t ctx_size = 0;
|
|
|
|
size_t ctx_mem_size = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
@ -715,12 +741,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
|
|
|
|
|
|
|
|
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
|
|
|
|
|
|
|
|
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
|
|
|
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
|
|
|
fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
|
|
@ -728,10 +748,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
|
|
|
|
|
|
|
|
// create the ggml context
|
|
|
|
// create the ggml context
|
|
|
|
{
|
|
|
|
{
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
struct ggml_init_params params;
|
|
|
|
.mem_size = wctx.buf_model->size(),
|
|
|
|
params.mem_size = wctx.buf_model->size();
|
|
|
|
.mem_buffer = wctx.buf_model->data(),
|
|
|
|
params.mem_buffer = wctx.buf_model->data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.ctx = ggml_init(params);
|
|
|
|
model.ctx = ggml_init(params);
|
|
|
|
if (!model.ctx) {
|
|
|
|
if (!model.ctx) {
|
|
|
@ -938,10 +957,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
|
|
|
|
|
|
|
|
// create the ggml memory context
|
|
|
|
// create the ggml memory context
|
|
|
|
{
|
|
|
|
{
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
struct ggml_init_params params;
|
|
|
|
.mem_size = wctx.buf_memory.size(),
|
|
|
|
params.mem_size = wctx.buf_memory.size();
|
|
|
|
.mem_buffer = wctx.buf_memory.data(),
|
|
|
|
params.mem_buffer = wctx.buf_memory.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
if (!model.ctx_mem) {
|
|
|
|
if (!model.ctx_mem) {
|
|
|
@ -998,9 +1016,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
int32_t length;
|
|
|
|
int32_t length;
|
|
|
|
int32_t ftype;
|
|
|
|
int32_t ftype;
|
|
|
|
|
|
|
|
|
|
|
|
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
|
|
|
read_safe(fin, n_dims);
|
|
|
|
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
|
|
|
read_safe(fin, length);
|
|
|
|
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
|
|
|
read_safe(fin, ftype);
|
|
|
|
|
|
|
|
|
|
|
|
if (fin.eof()) {
|
|
|
|
if (fin.eof()) {
|
|
|
|
break;
|
|
|
|
break;
|
|
|
@ -1009,14 +1027,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
|
|
|
int32_t nelements = 1;
|
|
|
|
int32_t nelements = 1;
|
|
|
|
int32_t ne[3] = { 1, 1, 1 };
|
|
|
|
int32_t ne[3] = { 1, 1, 1 };
|
|
|
|
for (int i = 0; i < n_dims; ++i) {
|
|
|
|
for (int i = 0; i < n_dims; ++i) {
|
|
|
|
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
|
|
|
read_safe(fin, ne[i]);
|
|
|
|
nelements *= ne[i];
|
|
|
|
nelements *= ne[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::string name(length, 0);
|
|
|
|
std::string name;
|
|
|
|
fin.read(&name[0], length);
|
|
|
|
std::vector<char> tmp(length); // create a buffer
|
|
|
|
|
|
|
|
fin.read(&tmp[0], tmp.size()); // read to buffer
|
|
|
|
|
|
|
|
name.assign(&tmp[0], tmp.size());
|
|
|
|
|
|
|
|
|
|
|
|
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
|
|
|
if (model.tensors.find(name) == model.tensors.end()) {
|
|
|
|
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
|
|
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -1088,10 +1108,9 @@ static bool whisper_encode(
|
|
|
|
const int n_mels = hparams.n_mels;
|
|
|
|
const int n_mels = hparams.n_mels;
|
|
|
|
assert(mel_inp.n_mel == n_mels);
|
|
|
|
assert(mel_inp.n_mel == n_mels);
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
struct ggml_init_params params;
|
|
|
|
.mem_size = wctx.buf_compute.size(),
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
|
.mem_buffer = wctx.buf_compute.data(),
|
|
|
|
params.mem_buffer = wctx.buf_compute.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
|
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
|
|
|
|
|
|
|
|
|
@ -1166,10 +1185,9 @@ static bool whisper_encode(
|
|
|
|
|
|
|
|
|
|
|
|
// create separate context for each layer to reduce memory usage
|
|
|
|
// create separate context for each layer to reduce memory usage
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params paramsL = {
|
|
|
|
struct ggml_init_params paramsL;
|
|
|
|
.mem_size = wctx.buf_compute_layer.size(),
|
|
|
|
paramsL.mem_size = wctx.buf_compute_layer.size();
|
|
|
|
.mem_buffer = wctx.buf_compute_layer.data(),
|
|
|
|
paramsL.mem_buffer = wctx.buf_compute_layer.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctxL = ggml_init(paramsL);
|
|
|
|
struct ggml_context * ctxL = ggml_init(paramsL);
|
|
|
|
|
|
|
|
|
|
|
@ -1374,8 +1392,8 @@ static bool whisper_encode(
|
|
|
|
// input for next layer (inpO -> inpL)
|
|
|
|
// input for next layer (inpO -> inpL)
|
|
|
|
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
|
|
|
|
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
|
|
|
|
inpL->op = GGML_OP_NONE;
|
|
|
|
inpL->op = GGML_OP_NONE;
|
|
|
|
inpL->src0 = NULL;
|
|
|
|
inpL->src0 = nullptr;
|
|
|
|
inpL->src1 = NULL;
|
|
|
|
inpL->src1 = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
|
|
|
|
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
|
|
|
|
|
|
|
|
|
|
|
@ -1428,8 +1446,8 @@ static bool whisper_encode(
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: hack to disconnect the encoded features from the previous graph
|
|
|
|
// TODO: hack to disconnect the encoded features from the previous graph
|
|
|
|
cur->op = GGML_OP_NONE;
|
|
|
|
cur->op = GGML_OP_NONE;
|
|
|
|
cur->src0 = NULL;
|
|
|
|
cur->src0 = nullptr;
|
|
|
|
cur->src1 = NULL;
|
|
|
|
cur->src1 = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
|
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
|
|
auto & layer = model.layers_decoder[il];
|
|
|
|
auto & layer = model.layers_decoder[il];
|
|
|
@ -1503,10 +1521,9 @@ static bool whisper_decode(
|
|
|
|
const int N = n_tokens;
|
|
|
|
const int N = n_tokens;
|
|
|
|
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
|
|
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
struct ggml_init_params params;
|
|
|
|
.mem_size = wctx.buf_compute.size(),
|
|
|
|
params.mem_size = wctx.buf_compute.size();
|
|
|
|
.mem_buffer = wctx.buf_compute.data(),
|
|
|
|
params.mem_buffer = wctx.buf_compute.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
|
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
|
|
|
|
|
|
|
|
|
@ -1529,10 +1546,9 @@ static bool whisper_decode(
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
|
const auto & layer = model.layers_decoder[il];
|
|
|
|
const auto & layer = model.layers_decoder[il];
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params paramsL = {
|
|
|
|
struct ggml_init_params paramsL;
|
|
|
|
.mem_size = wctx.buf_compute_layer.size(),
|
|
|
|
paramsL.mem_size = wctx.buf_compute_layer.size();
|
|
|
|
.mem_buffer = wctx.buf_compute_layer.data(),
|
|
|
|
paramsL.mem_buffer = wctx.buf_compute_layer.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctxL = ggml_init(paramsL);
|
|
|
|
struct ggml_context * ctxL = ggml_init(paramsL);
|
|
|
|
struct ggml_cgraph gf = {};
|
|
|
|
struct ggml_cgraph gf = {};
|
|
|
@ -1788,8 +1804,8 @@ static bool whisper_decode(
|
|
|
|
// input for next layer (inpO -> inpL)
|
|
|
|
// input for next layer (inpO -> inpL)
|
|
|
|
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
|
|
|
|
memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
|
|
|
|
inpL->op = GGML_OP_NONE;
|
|
|
|
inpL->op = GGML_OP_NONE;
|
|
|
|
inpL->src0 = NULL;
|
|
|
|
inpL->src0 = nullptr;
|
|
|
|
inpL->src1 = NULL;
|
|
|
|
inpL->src1 = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
if (N > 1) {
|
|
|
|
if (N > 1) {
|
|
|
|
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
|
|
|
|
//printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
|
|
|
@ -1845,7 +1861,7 @@ static bool whisper_decode(
|
|
|
|
|
|
|
|
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
// the most basic sampling scheme - select the top token
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
static whisper_token_data whisper_sample_best(
|
|
|
|
const whisper_vocab & vocab,
|
|
|
|
whisper_vocab & vocab,
|
|
|
|
const float * probs,
|
|
|
|
const float * probs,
|
|
|
|
bool force_timestamp,
|
|
|
|
bool force_timestamp,
|
|
|
|
bool is_initial) {
|
|
|
|
bool is_initial) {
|
|
|
@ -1853,13 +1869,13 @@ static whisper_token_data whisper_sample_best(
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
int n_logits = vocab.id_to_token.size();
|
|
|
|
const int n_logits = vocab.n_vocab;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
auto & probs_id = vocab.probs_id;
|
|
|
|
probs_id.reserve(n_logits);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
probs_id.clear();
|
|
|
|
for (int i = 0; i < n_logits; i++) {
|
|
|
|
for (int i = 0; i < n_logits; i++) {
|
|
|
|
probs_id.push_back(std::make_pair(probs[i], i));
|
|
|
|
probs_id.emplace_back(probs[i], i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
@ -1997,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
std::vector<float> even;
|
|
|
|
std::vector<float> even;
|
|
|
|
std::vector<float> odd;
|
|
|
|
std::vector<float> odd;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
even.reserve(N/2);
|
|
|
|
|
|
|
|
odd.reserve(N/2);
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
if (i % 2 == 0) {
|
|
|
|
if (i % 2 == 0) {
|
|
|
|
even.push_back(in[i]);
|
|
|
|
even.push_back(in[i]);
|
|
|
@ -2032,7 +2051,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
|
static bool log_mel_spectrogram(
|
|
|
|
static bool log_mel_spectrogram(
|
|
|
|
const float * samples,
|
|
|
|
const float * samples,
|
|
|
|
const int n_samples,
|
|
|
|
const int n_samples,
|
|
|
|
const int sample_rate,
|
|
|
|
const int /*sample_rate*/,
|
|
|
|
const int fft_size,
|
|
|
|
const int fft_size,
|
|
|
|
const int fft_step,
|
|
|
|
const int fft_step,
|
|
|
|
const int n_mel,
|
|
|
|
const int n_mel,
|
|
|
@ -2151,6 +2170,71 @@ static bool log_mel_spectrogram(
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// split text into tokens
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Regex (Python):
|
|
|
|
|
|
|
|
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Regex (C++):
|
|
|
|
|
|
|
|
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
|
|
|
|
|
|
|
|
std::vector<std::string> words;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// first split the text into words
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
std::string str = text;
|
|
|
|
|
|
|
|
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::regex re(pat);
|
|
|
|
|
|
|
|
std::smatch m;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while (std::regex_search(str, m, re)) {
|
|
|
|
|
|
|
|
for (auto x : m) {
|
|
|
|
|
|
|
|
words.push_back(x);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
str = m.suffix();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// find the longest tokens that form the words:
|
|
|
|
|
|
|
|
std::vector<whisper_vocab::id> tokens;
|
|
|
|
|
|
|
|
for (const auto & word : words) {
|
|
|
|
|
|
|
|
if (word.empty()) continue;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int i = 0;
|
|
|
|
|
|
|
|
int n = word.size();
|
|
|
|
|
|
|
|
while (i < n) {
|
|
|
|
|
|
|
|
int j = n;
|
|
|
|
|
|
|
|
while (j > i) {
|
|
|
|
|
|
|
|
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
|
|
|
|
|
|
|
if (it != vocab.token_to_id.end()) {
|
|
|
|
|
|
|
|
tokens.push_back(it->second);
|
|
|
|
|
|
|
|
i = j;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
--j;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (i == n) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (j == i) {
|
|
|
|
|
|
|
|
auto sub = word.substr(i, 1);
|
|
|
|
|
|
|
|
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
|
|
|
|
|
|
|
tokens.push_back(vocab.token_to_id.at(sub));
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
++i;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return tokens;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// interface implementation
|
|
|
|
// interface implementation
|
|
|
|
//
|
|
|
|
//
|
|
|
@ -2166,7 +2250,8 @@ struct whisper_context * whisper_init(const char * path_model) {
|
|
|
|
|
|
|
|
|
|
|
|
if (!whisper_model_load(path_model, *ctx)) {
|
|
|
|
if (!whisper_model_load(path_model, *ctx)) {
|
|
|
|
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
|
|
|
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
|
|
|
return NULL;
|
|
|
|
delete ctx;
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx->t_load_us = ggml_time_us() - t_start_us;
|
|
|
|
ctx->t_load_us = ggml_time_us() - t_start_us;
|
|
|
@ -2281,8 +2366,38 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
|
|
|
|
return res;
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
|
|
|
|
|
|
const auto res = tokenize(ctx->vocab, text);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (n_max_tokens < (int) res.size()) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) res.size(); i++) {
|
|
|
|
|
|
|
|
tokens[i] = res[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_lang_max_id() {
|
|
|
|
|
|
|
|
auto max_id = 0;
|
|
|
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
|
|
|
max_id = std::max(max_id, kv.second.first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return max_id;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_lang_id(const char * lang) {
|
|
|
|
int whisper_lang_id(const char * lang) {
|
|
|
|
if (!g_lang.count(lang)) {
|
|
|
|
if (!g_lang.count(lang)) {
|
|
|
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
|
|
|
if (kv.second.second == lang) {
|
|
|
|
|
|
|
|
return kv.second.first;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
|
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
|
|
return -1;
|
|
|
|
return -1;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2290,6 +2405,86 @@ int whisper_lang_id(const char * lang) {
|
|
|
|
return g_lang.at(lang).first;
|
|
|
|
return g_lang.at(lang).first;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const char * whisper_lang_str(int id) {
|
|
|
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
|
|
|
if (kv.second.first == id) {
|
|
|
|
|
|
|
|
return kv.first.c_str();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_lang_auto_detect(
|
|
|
|
|
|
|
|
struct whisper_context * ctx,
|
|
|
|
|
|
|
|
int offset_ms,
|
|
|
|
|
|
|
|
int n_threads,
|
|
|
|
|
|
|
|
float * lang_probs) {
|
|
|
|
|
|
|
|
const int seek = offset_ms/10;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (seek < 0) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (seek >= ctx->mel.n_len) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
|
|
|
|
|
|
|
|
return -2;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// run the encoder
|
|
|
|
|
|
|
|
if (whisper_encode(ctx, seek, n_threads) != 0) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
|
|
|
|
return -6;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
|
|
|
|
return -7;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<float, int>> probs_id;
|
|
|
|
|
|
|
|
for (const auto & kv : g_lang) {
|
|
|
|
|
|
|
|
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
|
|
|
|
|
|
|
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// sort descending
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
using pair_type = decltype(probs_id)::value_type;
|
|
|
|
|
|
|
|
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
|
|
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// softmax
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
float sum = 0;
|
|
|
|
|
|
|
|
for (const auto & kv : probs_id) {
|
|
|
|
|
|
|
|
sum += exp(kv.first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto & kv : probs_id) {
|
|
|
|
|
|
|
|
kv.first = exp(kv.first) / sum;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) probs_id.size(); i++) {
|
|
|
|
|
|
|
|
if (lang_probs) {
|
|
|
|
|
|
|
|
lang_probs[probs_id[i].second] = probs_id[i].first;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return probs_id[0].second;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_n_len(struct whisper_context * ctx) {
|
|
|
|
int whisper_n_len(struct whisper_context * ctx) {
|
|
|
|
return ctx->mel.n_len;
|
|
|
|
return ctx->mel.n_len;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2302,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
|
|
|
return ctx->model.hparams.n_text_ctx;
|
|
|
|
return ctx->model.hparams.n_text_ctx;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
|
|
|
|
|
|
|
return ctx->model.hparams.n_audio_ctx;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_is_multilingual(struct whisper_context * ctx) {
|
|
|
|
int whisper_is_multilingual(struct whisper_context * ctx) {
|
|
|
|
return ctx->vocab.is_multilingual() ? 1 : 0;
|
|
|
|
return ctx->vocab.is_multilingual() ? 1 : 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2338,6 +2537,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
|
|
|
|
return ctx->vocab.token_beg;
|
|
|
|
return ctx->vocab.token_beg;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns the token id for the language with id `lang_id`, computed as the
// start-of-transcript token of `ctx` plus 1 plus lang_id.
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
    const whisper_token first_lang = whisper_token_sot(ctx) + 1;
    return first_lang + lang_id;
}
|
|
|
|
|
|
|
|
|
|
|
|
whisper_token whisper_token_translate(void) {
|
|
|
|
whisper_token whisper_token_translate(void) {
|
|
|
|
return whisper_vocab::token_translate;
|
|
|
|
return whisper_vocab::token_translate;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2371,7 +2574,10 @@ const char * whisper_print_system_info(void) {
|
|
|
|
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
|
|
|
|
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
|
|
|
|
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
|
|
|
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
|
|
|
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
|
|
|
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
|
|
|
|
|
|
|
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
|
|
|
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
|
|
|
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
|
|
|
|
|
|
|
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
|
|
|
|
|
|
|
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
|
|
|
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
|
|
|
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
|
|
|
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
|
|
|
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
|
|
|
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
|
|
|
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
|
|
@ -2569,10 +2775,25 @@ int whisper_full(
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
|
|
|
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
|
|
|
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
|
|
|
return -1;
|
|
|
|
return -2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// auto-detect language if not specified
|
|
|
|
|
|
|
|
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
|
|
|
|
|
|
|
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
|
|
|
|
|
|
|
|
if (lang_id < 0) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
|
|
|
|
|
|
|
return -3;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params.language = whisper_lang_str(lang_id);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
if (params.token_timestamps) {
|
|
|
|
ctx->t_beg = 0;
|
|
|
|
ctx->t_beg = 0;
|
|
|
|
ctx->t_last = 0;
|
|
|
|
ctx->t_last = 0;
|
|
|
@ -2605,13 +2826,18 @@ int whisper_full(
|
|
|
|
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
|
|
|
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// overwrite audio_ctx
|
|
|
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
|
|
|
|
|
|
|
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
|
|
|
|
|
|
|
return -4;
|
|
|
|
|
|
|
|
}
|
|
|
|
ctx->exp_n_audio_ctx = params.audio_ctx;
|
|
|
|
ctx->exp_n_audio_ctx = params.audio_ctx;
|
|
|
|
|
|
|
|
|
|
|
|
// these tokens determine the task that will be performed
|
|
|
|
// these tokens determine the task that will be performed
|
|
|
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
|
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
|
|
if (whisper_is_multilingual(ctx)) {
|
|
|
|
if (whisper_is_multilingual(ctx)) {
|
|
|
|
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
|
|
|
|
const int lang_id = whisper_lang_id(params.language);
|
|
|
|
|
|
|
|
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
|
|
|
if (params.translate) {
|
|
|
|
if (params.translate) {
|
|
|
|
prompt_init.push_back(whisper_token_translate());
|
|
|
|
prompt_init.push_back(whisper_token_translate());
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -2639,10 +2865,17 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if only 1 second left, then stop
|
|
|
|
if (seek + 100 >= seek_end) {
|
|
|
|
if (seek + 100 >= seek_end) {
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// if there is a very short audio segment left to process, we remove any past prompt since it tends
|
|
|
|
|
|
|
|
// to confuse the decoder and often make it repeat or hallucinate stuff
|
|
|
|
|
|
|
|
if (seek > seek_start && seek + 500 >= seek_end) {
|
|
|
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (params.encoder_begin_callback) {
|
|
|
|
if (params.encoder_begin_callback) {
|
|
|
|
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
|
|
|
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
|
|
|
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
|
|
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
|
@ -2653,14 +2886,14 @@ int whisper_full(
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
// encode audio features starting at offset seek
|
|
|
|
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
|
|
|
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
|
|
return 7;
|
|
|
|
return -4;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int n_past = 0;
|
|
|
|
int n_past = 0;
|
|
|
|
prompt.clear();
|
|
|
|
prompt.clear();
|
|
|
|
|
|
|
|
|
|
|
|
// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
|
// if we have already generated some text, use it as a prompt to condition the next generation
|
|
|
|
if (prompt_past.size() > 0) {
|
|
|
|
if (!prompt_past.empty()) {
|
|
|
|
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
|
|
|
|
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
|
|
|
|
|
|
|
|
|
|
|
|
prompt = { whisper_token_prev(ctx) };
|
|
|
|
prompt = { whisper_token_prev(ctx) };
|
|
|
@ -2686,11 +2919,12 @@ int whisper_full(
|
|
|
|
tokens_cur.clear();
|
|
|
|
tokens_cur.clear();
|
|
|
|
|
|
|
|
|
|
|
|
bool failed = false;
|
|
|
|
bool failed = false;
|
|
|
|
|
|
|
|
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
|
|
return 8;
|
|
|
|
return -5;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
n_past += prompt.size();
|
|
|
|
n_past += prompt.size();
|
|
|
@ -2711,13 +2945,13 @@ int whisper_full(
|
|
|
|
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
// do not allow to go back in time
|
|
|
|
// do not allow to go back in time
|
|
|
|
if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
|
|
|
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
seek_delta > seek_delta_new && result_len < i) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
seek_delta = seek_delta_new;
|
|
|
|
seek_delta = seek_delta_new;
|
|
|
|
result_len = i + 1;
|
|
|
|
result_len = i + 1;
|
|
|
|
|
|
|
|
has_ts = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// add it to the context
|
|
|
|
// add it to the context
|
|
|
@ -2726,11 +2960,14 @@ int whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
//{
|
|
|
|
//{
|
|
|
|
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
|
|
|
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
// printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
|
|
|
//}
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
// end of text token
|
|
|
|
// end of segment
|
|
|
|
if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
|
|
|
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
|
|
|
|
|
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
|
|
|
|
|
|
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
|
|
|
|
|
|
) {
|
|
|
|
if (result_len == 0) {
|
|
|
|
if (result_len == 0) {
|
|
|
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
result_len = i + 1;
|
|
|
|
result_len = i + 1;
|
|
|
@ -2765,8 +3002,14 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (failed) {
|
|
|
|
if (failed) {
|
|
|
|
fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
|
|
|
|
// when we fail to sample timestamp token, retry by clearing the past prompt
|
|
|
|
seek += 100;
|
|
|
|
// if it fails again, then we advance the window by 1 second
|
|
|
|
|
|
|
|
if (!prompt_past.empty()) {
|
|
|
|
|
|
|
|
prompt_past.clear();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
|
|
|
|
|
|
|
|
seek += 100;
|
|
|
|
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -2778,11 +3021,11 @@ int whisper_full(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// store the text from this iteration
|
|
|
|
// store the text from this iteration
|
|
|
|
if (tokens_cur.size() > 0) {
|
|
|
|
if (!tokens_cur.empty()) {
|
|
|
|
int i0 = 0;
|
|
|
|
int i0 = 0;
|
|
|
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
|
|
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
|
|
|
|
|
|
|
|
|
|
|
std::string text = "";
|
|
|
|
std::string text;
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < (int) tokens_cur.size(); i++) {
|
|
|
|
for (int i = 0; i < (int) tokens_cur.size(); i++) {
|
|
|
|
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
|
|
|
|
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
|
|
|
@ -2901,10 +3144,9 @@ int whisper_full_parallel(
|
|
|
|
|
|
|
|
|
|
|
|
// create the ggml memory context
|
|
|
|
// create the ggml memory context
|
|
|
|
{
|
|
|
|
{
|
|
|
|
struct ggml_init_params params = {
|
|
|
|
struct ggml_init_params params;
|
|
|
|
.mem_size = ctxs[i].buf_memory.size(),
|
|
|
|
params.mem_size = ctxs[i].buf_memory.size();
|
|
|
|
.mem_buffer = ctxs[i].buf_memory.data(),
|
|
|
|
params.mem_buffer = ctxs[i].buf_memory.data();
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
model.ctx_mem = ggml_init(params);
|
|
|
|
if (!model.ctx_mem) {
|
|
|
|
if (!model.ctx_mem) {
|
|
|
@ -2990,7 +3232,7 @@ int whisper_full_parallel(
|
|
|
|
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
|
|
|
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
|
|
|
|
|
|
|
|
|
|
|
// make sure that segments are not overlapping
|
|
|
|
// make sure that segments are not overlapping
|
|
|
|
if (ctx->result_all.size() > 0) {
|
|
|
|
if (!ctx->result_all.empty()) {
|
|
|
|
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
|
|
|
|
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|