whisper : temperature + best_of support

pull/291/head
Georgi Gerganov 3 years ago
parent 712bc4b960
commit 34c5110f59

@@ -15,6 +15,7 @@
 #include <thread>
 #include <vector>
 #include <regex>
+#include <random>
 
 #define WHISPER_ASSERT(x) \
     do { \
@@ -479,6 +480,8 @@ struct whisper_context {
     std::vector<std::pair<double, whisper_vocab::id>> logits_id;
 
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
     int64_t t_last;
@@ -1180,6 +1183,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }
 
+    wctx.rng = std::mt19937(0);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -2848,9 +2853,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
 //  - applyies logit filters
 //  - computes logprobs
 static void whisper_process_logits(
-              struct whisper_context & ctx,
+        const struct whisper_context & ctx,
+        const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
-              struct whisper_full_params   params) {
+              float   temperature) {
     const auto & vocab = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -2868,6 +2874,12 @@ static void whisper_process_logits(
     logits.resize(n_logits);
     memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
 
+    if (temperature > 0.0f) {
+        for (int i = 0; i < n_logits; i++) {
+            logits[i] /= temperature;
+        }
+    }
+
     // will be populated a bit later
     probs.resize(n_logits);
     logprobs.resize(n_logits);
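
For context: dividing each logit by the temperature before the softmax sharpens the distribution when T < 1 and flattens it when T > 1, which is what makes the non-greedy decoding path below controllably random. A minimal standalone sketch of that effect (illustrative only, not part of this commit; softmax() here is a local helper, not a whisper.cpp function):

// Standalone illustration of temperature scaling (not from whisper.cpp).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// softmax over logits, optionally scaled by a temperature
static std::vector<float> softmax(std::vector<float> logits, float temperature) {
    if (temperature > 0.0f) {
        for (auto & l : logits) {
            l /= temperature; // same scaling as in whisper_process_logits above
        }
    }

    float max_l = logits[0];
    for (float l : logits) max_l = std::max(max_l, l);

    double sum = 0.0;
    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l); // subtract max for numerical stability
        sum += probs[i];
    }
    for (auto & p : probs) p /= sum;

    return probs;
}

int main() {
    const std::vector<float> logits = { 2.0f, 1.0f, 0.1f };
    for (float T : { 0.5f, 1.0f, 1.5f }) {
        const auto probs = softmax(logits, T);
        printf("T = %.1f -> %.3f %.3f %.3f\n", T, probs[0], probs[1], probs[2]);
    }
    return 0;
}

Lower temperatures concentrate probability mass on the top token, higher temperatures spread it out, so sampled transcriptions become more diverse as the fallback temperature rises.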
@@ -3018,10 +3030,10 @@ static void whisper_process_logits(
 #endif
 }
 
-// select the most probable token
-static whisper_token_data whisper_sample_best(
-        whisper_context & ctx,
-        whisper_decoder & decoder) {
+static whisper_token_data whisper_sample_token(
+      const whisper_context & ctx,
+      const whisper_decoder & decoder,
+                       bool   best) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
@@ -3053,6 +3065,7 @@ static whisper_token_data whisper_sample_best(
         result.ptsum = sum_ts;
     }
 
+    if (best) {
     for (int i = 0; i < n_logits; ++i) {
         if (result.p < probs[i]) {
             result.id = i;
@@ -3060,6 +3073,13 @@ static whisper_token_data whisper_sample_best(
             result.plog = logprobs[i];
         }
     }
+    } else {
+        std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+        result.id = dist(ctx.rng);
+        result.p = probs[result.id];
+        result.plog = logprobs[result.id];
+    }
 
     return result;
 }
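
The renamed whisper_sample_token merges the old greedy helper and the new sampling path: with best == true it returns the arg-max token, otherwise it draws an id from the token probabilities via the context's mt19937 rng. A small self-contained sketch of the same selection logic (illustrative; sample_token is a hypothetical stand-in, not the whisper.cpp function):

// Standalone sketch of greedy-vs-sampled token selection (illustrative).
#include <cstdio>
#include <random>
#include <vector>

// Returns a token id: arg-max when best == true, otherwise a draw
// from the categorical distribution defined by probs.
static int sample_token(const std::vector<float> & probs, std::mt19937 & rng, bool best) {
    if (best) {
        int id = 0;
        for (size_t i = 1; i < probs.size(); ++i) {
            if (probs[i] > probs[id]) {
                id = (int) i;
            }
        }
        return id;
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return dist(rng);
}

int main() {
    std::mt19937 rng(0); // fixed seed, mirroring wctx.rng = std::mt19937(0) above
    const std::vector<float> probs = { 0.1f, 0.7f, 0.2f };

    printf("greedy : %d\n", sample_token(probs, rng, true));
    for (int i = 0; i < 5; ++i) {
        printf("sampled: %d\n", sample_token(probs, rng, false));
    }
    return 0;
}

Using a single function with a best flag keeps the caller's switch on the sampling strategy simple: the greedy case just passes true when t_cur is effectively zero and false otherwise, as seen further down in whisper_full.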
@@ -3167,22 +3187,24 @@ int whisper_full(
             } break;
     };
 
-    for (int i = 1; i < n_decoders; i++) {
-        // TAGS: WHISPER_DECODER_INIT
-        if (ctx->decoders[i].kv_self.ctx == nullptr) {
-            ctx->decoders[i].kv_self = ctx->decoders[0].kv_self;
-            if (!kv_cache_reinit(ctx->decoders[i].kv_self)) {
-                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i);
+    // TAGS: WHISPER_DECODER_INIT
+    for (int j = 1; j < n_decoders; j++) {
+        auto & decoder = ctx->decoders[j];
+
+        if (decoder.kv_self.ctx == nullptr) {
+            decoder.kv_self = ctx->decoders[0].kv_self;
+
+            if (!kv_cache_reinit(decoder.kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
             }
 
-            fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i);
+            fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
 
-            ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
 
-            ctx->decoders[i].probs.resize   (ctx->vocab.n_vocab);
-            ctx->decoders[i].logits.resize  (ctx->vocab.n_vocab);
-            ctx->decoders[i].logprobs.resize(ctx->vocab.n_vocab);
+            decoder.probs.resize   (ctx->vocab.n_vocab);
+            decoder.logits.resize  (ctx->vocab.n_vocab);
+            decoder.logprobs.resize(ctx->vocab.n_vocab);
         }
     }
@@ -3268,6 +3290,7 @@ int whisper_full(
         const float t_cur = temperatures[it];
 
         int n_decoders_cur = 1;
+
         switch (params.strategy) {
             case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                 {
@ -3330,7 +3353,7 @@ int whisper_full(
// print the prompt // print the prompt
//printf("\n\n"); //printf("\n\n");
//for (int i = 0; i < (int) prompt.size(); i++) { //for (int i = 0; i < (int) prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
//} //}
//printf("\n\n"); //printf("\n\n");
@@ -3342,7 +3365,7 @@ int whisper_full(
        {
            const int64_t t_start_sample_us = ggml_time_us();
 
-           whisper_process_logits(*ctx, ctx->decoders[0], params);
+           whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
 
            ctx->decoders[0].n_past += prompt.size();
@@ -3376,8 +3399,9 @@ int whisper_full(
                    case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                        {
                            if (t_cur < 1e-6f) {
-                               decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder));
+                               decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
                            } else {
+                               decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
                            }
                        } break;
                    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
@@ -3402,7 +3426,7 @@ int whisper_full(
                    // do not allow to go back in time
                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                        failed = true; // TODO: maybe this is not a failure ?
-                       break;
+                       continue;
                    }
 
                    seek_delta = seek_delta_new;
@@ -3410,11 +3434,11 @@ int whisper_full(
                        has_ts = true;
                    }
 
-                   {
-                       const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                       printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
-                               __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
-                   }
+                   //{
+                   //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
+                   //    printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                   //            __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
+                   //}
 
                    // end of segment
                    if (token.id == whisper_token_eot(ctx) ||  // end of text token
@@ -3426,7 +3450,7 @@ int whisper_full(
                            result_len = i + 1;
                        } else {
                            failed = true;
-                           break;
+                           continue;
                        }
                    }
 
@@ -3436,14 +3460,14 @@ int whisper_full(
                        }
 
                        completed = true;
-                       break;
+                       continue;
                    }
 
                    // TESTS: if no tensors are loaded, it means we are running tests
                    if (ctx->model.n_loaded == 0) {
                        seek_delta = 100*WHISPER_CHUNK_SIZE;
                        completed = true;
-                       break;
+                       continue;
                    }
                }
 
@@ -3451,7 +3475,7 @@ int whisper_full(
                // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                    failed = true;
-                   break;
+                   continue;
                }
            }
@@ -3496,7 +3520,7 @@ int whisper_full(
                {
                    const int64_t t_start_sample_us = ggml_time_us();
 
-                   whisper_process_logits(*ctx, decoder, params);
+                   whisper_process_logits(*ctx, params, decoder, t_cur);
 
                    ++decoder.n_past;
@@ -3524,7 +3548,7 @@ int whisper_full(
                    best_decoder_id = j;
                }
 
-               fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
+               //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
            }
        }
@@ -3541,6 +3565,11 @@ int whisper_full(
            }
 
            if (success) {
+               //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
+               //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+               //    fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+               //}
+
                break;
            }
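
Taken together with the break -> continue changes above, whisper_full now retries a segment instead of bailing out: decoding is first attempted at t = 0 (greedy), and if the attempt is flagged as failed, it is repeated at the next entry of the temperatures schedule. A rough sketch of that fallback control flow, with hypothetical helpers (decode_once, result_t) standing in for the real decoder loop:

// Rough sketch of temperature fallback (hypothetical helpers, not the whisper.cpp API).
#include <cstdio>
#include <vector>

struct result_t {
    bool failed = false;
    // ... decoded tokens, scores, etc.
};

// Hypothetical single decoding pass at a given temperature.
static result_t decode_once(float temperature) {
    result_t res;
    res.failed = (temperature < 0.4f); // pretend low-temperature attempts fail here
    return res;
}

int main() {
    // Each retry bumps the temperature, e.g. 0.0, 0.2, 0.4, ...
    const std::vector<float> temperatures = { 0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };

    for (size_t it = 0; it < temperatures.size(); ++it) {
        const float t_cur = temperatures[it];

        const result_t res = decode_once(t_cur);
        if (!res.failed) {
            printf("success at temperature %.1f\n", t_cur);
            break; // keep this result and move on to the next audio segment
        }
        printf("failed at temperature %.1f, retrying\n", t_cur);
    }
    return 0;
}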
