whisper : prepare infra for new decoding strategies

pull/291/head
Georgi Gerganov 3 years ago
parent 4ef3398e8f
commit abe104b80b

@@ -16,6 +16,14 @@
 #include <vector>
 #include <regex>
 
+#define WHISPER_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
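
Note: the new WHISPER_ASSERT macro aborts and prints the failing expression with its file and line. A minimal standalone sketch of its behavior (the checked values are illustrative, not from the commit):

// Illustrative only: standalone demo of the new assertion macro.
#include <cstdio>
#include <cstdlib>
#include <vector>

#define WHISPER_ASSERT(x) \
    do { \
        if (!(x)) { \
            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
            abort(); \
        } \
    } while (0)

int main() {
    std::vector<float> logits(8, 0.0f);
    WHISPER_ASSERT((int) logits.size() == 8); // holds, execution continues
    WHISPER_ASSERT(!logits.empty());          // a failing check would print the message and abort()
    return 0;
}
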
@@ -423,8 +431,9 @@ struct whisper_context {
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
+    std::vector<float> work_logits; // used to avoid allocations
 
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
@@ -2689,12 +2698,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
+        /*.suppress_blank   =*/ true,
+
         /*.greedy           =*/ {
-            /*.n_past =*/ 0,
+            /*.dummy  =*/ 0,
         },
 
         /*.beam_search      =*/ {
-            /*.n_past     =*/ -1,
             /*.beam_width =*/ -1,
             /*.n_best     =*/ -1,
         },
@@ -2738,12 +2748,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
+        /*.suppress_blank   =*/ true,
+
         /*.greedy           =*/ {
-            /*.n_past =*/ -1,
+            /*.dummy  =*/ 0,
         },
 
         /*.beam_search      =*/ {
-            /*.n_past     =*/ 0,
             /*.beam_width =*/ 10,
             /*.n_best     =*/ 5,
         },
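
Note: callers pick a strategy via whisper_full_default_params and can override the strategy-specific fields before calling whisper_full. A hedged usage sketch based only on the fields visible in this diff (beam search is still a TODO in this commit, so only the greedy path actually decodes):

// Sketch of client-side usage, assuming the fields introduced in this commit.
#include "whisper.h"

void setup_example(struct whisper_context * ctx, const float * samples, int n_samples) {
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.language       = "en";
    params.suppress_blank = true;   // new field added by this commit

    // the beam-search variant exposes its own knobs, but the strategy is not implemented yet:
    //   struct whisper_full_params bs = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
    //   bs.beam_search.beam_width = 10;
    //   bs.beam_search.n_best     = 5;

    whisper_full(ctx, params, samples, n_samples); // runs the greedy pipeline
}
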
@@ -2822,6 +2833,50 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     return res;
 }
 
+static struct whisper_token_data whisper_sample_next_token(
+            struct whisper_context * ctx,
+            struct whisper_full_params   params,
+      const std::vector<whisper_token>      & prompt,
+      const std::vector<whisper_token_data> & tokens_cur) {
+    struct whisper_token_data result = {};
+
+    const auto & vocab = ctx->vocab;
+
+    const bool is_initial = tokens_cur.size() == 0;
+    const int  n_logits   = vocab.id_to_token.size();
+
+    WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);
+
+    // extract the logits for the last token
+    // we will be mutating and therefore we don't want to use the ctx->logits buffer directly
+    auto & logits = ctx->work_logits;
+    {
+        logits.resize(n_logits);
+        memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
+    }
+
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+    // TODO: apply logit filters here
+    {
+    }
+
+    switch (params.strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
+                // TODO: implement
+                result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+            } break;
+        case WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
+                // TODO: implement
+            } break;
+    }
+
+    return result;
+}
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params   params,
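
Note: the empty "TODO: apply logit filters here" block is where filters like OpenAI's SuppressBlank would plug in (see the referenced decoding.py lines), presumably driven by the new params.suppress_blank flag. A hypothetical sketch of what such a filter could look like; the function name and the choice of suppressed tokens are assumptions for illustration, not part of this commit:

// Hypothetical filter sketch (not in this commit): suppress "blank" continuations
// on the first sampled token of a segment, mirroring SuppressBlank from
// openai/whisper's decoding.py.
#include <cmath>
#include <vector>

#include "whisper.h"   // for whisper_token

static void filter_suppress_blank_sketch(
        std::vector<float> & logits,
        whisper_token        token_eot,    // e.g. whisper_token_eot(ctx)
        whisper_token        token_space,  // id of the " " token, looked up from the vocab
        bool                 is_initial) {
    if (!is_initial) {
        return; // SuppressBlank only applies to the first token of a segment
    }
    logits[token_eot]   = -INFINITY; // disallow immediate end-of-text
    logits[token_space] = -INFINITY; // disallow a lone space as the first token
}
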
@@ -2955,7 +3010,6 @@ int whisper_full(
             return -4;
         }
 
-        int n_past = 0;
         prompt.clear();
 
         // if we have already generated some text, use it as a prompt to condition the next generation
@@ -2971,8 +3025,6 @@ int whisper_full(
         prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-        int seek_delta = 100*WHISPER_CHUNK_SIZE;
-
         // print the prompt
         //printf("\n\n");
         //for (int i = 0; i < prompt.size(); i++) {
@@ -2980,11 +3032,14 @@ int whisper_full(
         //}
         //printf("\n\n");
 
+        int n_past = 0;
+        int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
         // the accumulated transcription in the current interation
         int result_len = 0;
         tokens_cur.clear();
 
-        bool failed = false;
+        bool failed = false; // has the current segment failed to decode?
         bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
 
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
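
Note: the loop budget n_max is half the text context minus a small reserve. A worked example of that arithmetic; 448 is the text context of the published Whisper models (what whisper_n_text_ctx returns for them), so the exact figure is model-dependent and only illustrative:

// Illustrative arithmetic for the per-window token budget.
#include <cstdio>

int main() {
    const int n_text_ctx = 448;              // whisper_n_text_ctx(ctx) for the stock models
    const int n_max      = n_text_ctx/2 - 4; // = 220 tokens sampled per 30-second window, at most
    printf("n_max = %d\n", n_max);
    return 0;
}
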
@@ -2996,15 +3051,10 @@ int whisper_full(
             n_past += prompt.size();
             prompt.clear();
 
-            // very basic greedy sampling strategy:
-            //
-            //   - always take the most probable token
-            //
-            // more sophisticated sampling strategies could be implemented here, but we keep it simple
-            // feel free to experiment!
-            //
+            // sample the next token based on the selected decoding strategy + parameters
+            // also, update the sliding window position based on the sampled timestamp tokens
             {
-                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+                const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);
 
                 // timestamp token - update sliding window
                 if (token.id > whisper_token_beg(ctx)) {
@@ -3059,8 +3109,7 @@ int whisper_full(
             }
 
             // sometimes, the decoding can get stuck in a repetition loop
-            // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
-            // the sliding window by 1 second
+            // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
             if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                 failed = true;
                 break;
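
Note: seek_delta is measured in units of 10 ms, and WHISPER_CHUNK_SIZE is 30 (seconds) in whisper.h, so the threshold 100*WHISPER_CHUNK_SIZE/2 works out to 1500, i.e. 15 seconds. A small sketch of just that part of the condition (the seek_delta value is hypothetical):

// Worked example of the fallback threshold; only the seek_delta half of the
// condition is shown (the real check also requires i == n_max - 1).
#include <cstdio>

int main() {
    const int chunk_size = 30;               // WHISPER_CHUNK_SIZE, in seconds
    const int threshold  = 100*chunk_size/2; // 1500 units of 10 ms = 15 s
    const int seek_delta = 800;              // hypothetical: window advanced only 8 s
    const bool failed    = seek_delta < threshold;
    printf("threshold = %d, failed = %d\n", threshold, failed); // failed = 1
    return 0;
}
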

@@ -274,12 +274,13 @@ extern "C" {
         // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
 
+        bool suppress_blank;
+
         struct {
-            int n_past;
+            int dummy;
         } greedy;
 
         struct {
-            int n_past;
             int beam_width;
             int n_best;
         } beam_search;
