@@ -16,6 +16,14 @@
 #include <vector>
 #include <regex>
 
+#define WHISPER_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
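Note: WHISPER_ASSERT uses the standard do { ... } while (0) wrapper so the macro expands to a single statement and composes safely with unbraced if/else. A minimal call site (this one appears verbatim in the whisper_sample_next_token hunk further down):

    // aborts with file, line and the failing expression when the condition is false
    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);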
@@ -423,8 +431,9 @@ struct whisper_context {
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
+    std::vector<float> work_logits; // used to avoid allocations
 
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
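Note: work_logits is a scratch buffer owned by the context so that the per-token copy of the final logits row reuses the same storage instead of allocating on every sampling call. The intended pattern, mirrored from the whisper_sample_next_token hunk below:

    // resize is a no-op after the first call, so the copy does not allocate per token
    auto & logits = ctx->work_logits;
    logits.resize(n_logits);
    memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));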
@@ -2689,12 +2698,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
             /*.language         =*/ "en",
 
+            /*.suppress_blank   =*/ true,
+
             /*.greedy           =*/ {
-                /*.n_past =*/ 0,
+                /*.dummy  =*/ 0,
             },
 
             /*.beam_search      =*/ {
-                /*.n_past     =*/ -1,
                 /*.beam_width =*/ -1,
                 /*.n_best     =*/ -1,
             },
 
@@ -2738,12 +2748,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
             /*.language         =*/ "en",
 
+            /*.suppress_blank   =*/ true,
+
             /*.greedy           =*/ {
-                /*.n_past =*/ -1,
+                /*.dummy  =*/ 0,
             },
 
             /*.beam_search      =*/ {
-                /*.n_past     =*/ 0,
                 /*.beam_width =*/ 10,
                 /*.n_best     =*/ 5,
             },
 
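Note: both default branches now expose suppress_blank, and the unused greedy.n_past field becomes a dummy placeholder. A sketch of how a caller would pick a strategy and override a default; pcmf32 holding 16 kHz mono float samples is an assumption of the example, not part of the patch:

    // request greedy defaults, then tweak individual fields before running
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.suppress_blank = false; // e.g. allow blank/EOT candidates at segment start

    if (whisper_full(ctx, params, pcmf32.data(), pcmf32.size()) != 0) {
        fprintf(stderr, "whisper_full() failed\n");
    }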
@@ -2822,6 +2833,50 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     return res;
 }
 
+static struct whisper_token_data whisper_sample_next_token(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const std::vector<whisper_token> & prompt,
+        const std::vector<whisper_token_data> & tokens_cur) {
+    struct whisper_token_data result = {};
+
+    const auto & vocab = ctx->vocab;
+
+    const bool is_initial = tokens_cur.size() == 0;
+    const int  n_logits   = vocab.id_to_token.size();
+
+    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
+
+    // extract the logits for the last token
+    // we will be mutating them, so we don't want to use the ctx->logits buffer directly
+    auto & logits = ctx->work_logits;
+    {
+        logits.resize(n_logits);
+        memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
+    }
+
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+    // TODO: apply logit filters here
+    {
+    }
+
+    switch (params.strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
+                // TODO: implement
+                result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+            } break;
+        case WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
+                // TODO: implement
+            } break;
+    }
+
+    return result;
+}
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
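Note: the logit-filter block above is still a stub. The decoding.py lines it references implement OpenAI's SuppressBlank filter; a hedged sketch of a C++ equivalent follows. It assumes the existing whisper_vocab members token_eot and token_to_id and needs <cmath> for INFINITY; the " " lookup is an assumption about the tokenizer, not something this patch defines:

    // sketch: at the very first sampling step, forbid starting with a blank
    // token or ending immediately with EOT (mirrors SuppressBlank in decoding.py)
    if (params.suppress_blank && is_initial) {
        logits[vocab.token_eot] = -INFINITY;

        const auto it = vocab.token_to_id.find(" ");
        if (it != vocab.token_to_id.end()) {
            logits[it->second] = -INFINITY;
        }
    }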
@@ -2955,7 +3010,6 @@ int whisper_full(
             return -4;
         }
 
-        int n_past = 0;
         prompt.clear();
 
         // if we have already generated some text, use it as a prompt to condition the next generation
@@ -2971,8 +3025,6 @@ int whisper_full(
         prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-        int seek_delta = 100*WHISPER_CHUNK_SIZE;
-
         // print the prompt
         //printf("\n\n");
         //for (int i = 0; i < prompt.size(); i++) {
@@ -2980,11 +3032,14 @@ int whisper_full(
         //}
         //printf("\n\n");
 
+        int n_past = 0;
+        int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
         // the accumulated transcription in the current iteration
         int result_len = 0;
         tokens_cur.clear();
 
-        bool failed = false;
+        bool failed = false; // has the current segment failed to decode?
         bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
 
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
@@ -2996,15 +3051,10 @@ int whisper_full(
             n_past += prompt.size();
             prompt.clear();
 
-            // very basic greedy sampling strategy:
-            //
-            //   - always take the most probable token
-            //
-            // more sophisticated sampling strategies could be implemented here, but we keep it simple
-            // feel free to experiment!
-            //
+            // sample the next token based on the selected decoding strategy + parameters
+            // also, update the sliding window position based on the sampled timestamp tokens
             {
-                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+                const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);
 
                 // timestamp token - update sliding window
                 if (token.id > whisper_token_beg(ctx)) {
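Note: the unchanged branch that follows is where the sliding window actually moves. For orientation, a simplified version of the existing update (not part of this patch; the real code has extra guards against seeking backwards):

    // timestamp tokens are spaced 20 ms apart while seek offsets count 10 ms
    // frames, hence the factor of 2
    if (token.id > whisper_token_beg(ctx)) {
        seek_delta = 2*(token.id - whisper_token_beg(ctx));
        result_len = i + 1;
        has_ts     = true;
    }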
@@ -3059,8 +3109,7 @@ int whisper_full(
                 }
 
                 // sometimes, the decoding can get stuck in a repetition loop
-                // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
-                // the sliding window by 1 second
+                // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                 if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                     failed = true;
                     break;
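Note: the rewritten comment promises a fallback rather than a bare one-second skip, but the fallback itself is not in this hunk. In the spirit of the removed comment, the minimal handling after the token loop would look like this (a sketch; the exact recovery strategy is left open by the TODOs above):

    // if decoding failed, advance the window by 1 second (100 frames of 10 ms)
    // and try again from there
    if (failed) {
        fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n", __func__);
        seek += 100;
        continue;
    }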