From 02f0c6fe7f9b7be24c7d339aed016e54a92388ea Mon Sep 17 00:00:00 2001
From: beiller
Date: Sun, 12 Mar 2023 16:23:15 -0400
Subject: [PATCH] Add back top_k (#56)

* Add back top_k

* Update utils.cpp

* Update utils.h

---------

Co-authored-by: Bill Hamilton
Co-authored-by: Georgi Gerganov
---
 main.cpp  |  3 ++-
 utils.cpp | 79 +++----------------------------------------------------
 utils.h   | 19 +++++--------
 3 files changed, 12 insertions(+), 89 deletions(-)

diff --git a/main.cpp b/main.cpp
index a11d755..0155614 100644
--- a/main.cpp
+++ b/main.cpp
@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
 
         if (i >= embd_inp.size()) {
             // sample next token
+            const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp  = params.temp;
             const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
diff --git a/utils.cpp b/utils.cpp
index 58e7070..5435d47 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     return true;
 }
 
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
 
     // find the top K tokens
     std::partial_sort(
             logits_id.begin(),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     });
 
     logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
 }
 
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int    top_k,
         double top_p,
         double temp,
         std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
         }
     }
 
-    std::sort(
-            logits_id.begin(),
-            logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
+    sample_top_k(logits_id, top_k);
 
     double maxl = -INFINITY;
     for (const auto & kv : logits_id) {
diff --git a/utils.h b/utils.h
index e331904..5b3d736 100644
--- a/utils.h
+++ b/utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64;   // last n tokens to penalize
 
     // sampling parameters
-    int32_t top_k = 40; // unused
+    int32_t top_k = 40;
     float   top_p = 0.95f;
     float   temp  = 0.80f;
     float   repeat_penalty  = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 //   - consider only the top K tokens
 //   - from them, consider only the top tokens with cumulative probability > P
 //
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int    top_k,
         double top_p,
         double temp,
         std::mt19937 & rng);
 
+// filter to top K tokens from the list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+
 //
 // Quantization
 //
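
Note: for readers who want to try the sampling path this patch wires up outside the tree, here is a minimal, self-contained sketch of the flow: scale logits by 1/temp, penalize recently emitted tokens, keep the top K candidates with a partial sort, truncate to the smallest set whose cumulative probability reaches P, and draw from the result. sample_top_k() mirrors the helper in this diff; the std::vector<float> interface, the toy inputs, and the condensed repeat-penalty expression are simplifications for illustration, not code from the patch.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <random>
    #include <utility>
    #include <vector>

    using token_id = int;

    // same role as the patch's sample_top_k(): keep only the K best-scoring tokens
    static void sample_top_k(std::vector<std::pair<double, token_id>> & logits_id, int top_k) {
        std::partial_sort(
                logits_id.begin(),
                logits_id.begin() + top_k, logits_id.end(),
                [](const std::pair<double, token_id> & a, const std::pair<double, token_id> & b) {
            return a.first > b.first;
        });
        logits_id.resize(top_k);
    }

    // simplified stand-in for llama_sample_top_p_top_k()
    static token_id sample_top_p_top_k(
            const std::vector<float> & logits,
            const std::vector<token_id> & last_n_tokens,
            double repeat_penalty,
            int top_k,
            double top_p,
            double temp,
            std::mt19937 & rng) {
        std::vector<std::pair<double, token_id>> logits_id;
        logits_id.reserve(logits.size());

        const double scale = 1.0/temp;
        for (int i = 0; i < (int) logits.size(); ++i) {
            double l = logits[i]*scale;
            // penalize tokens that appeared recently (condensed form of the patch's branch)
            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                l = l > 0.0 ? l/repeat_penalty : l*repeat_penalty;
            }
            logits_id.emplace_back(l, i);
        }

        top_k = std::min(top_k, (int) logits_id.size());
        sample_top_k(logits_id, top_k);

        // softmax over the surviving candidates
        double maxl = -INFINITY;
        for (const auto & kv : logits_id) {
            maxl = std::max(maxl, kv.first);
        }
        std::vector<double> probs;
        probs.reserve(logits_id.size());
        double sum = 0.0;
        for (const auto & kv : logits_id) {
            const double p = std::exp(kv.first - maxl);
            probs.push_back(p);
            sum += p;
        }
        for (auto & p : probs) {
            p /= sum;
        }

        // nucleus (top-p) truncation: keep the shortest prefix covering top_p
        if (top_p < 1.0) {
            double cumsum = 0.0;
            for (size_t i = 0; i < probs.size(); ++i) {
                cumsum += probs[i];
                if (cumsum >= top_p) {
                    probs.resize(i + 1);
                    logits_id.resize(i + 1);
                    break;
                }
            }
        }

        // discrete_distribution renormalizes the truncated weights internally
        std::discrete_distribution<> dist(probs.begin(), probs.end());
        return logits_id[dist(rng)].second;
    }

    int main() {
        const std::vector<float>    logits = { 1.0f, 3.0f, 0.5f, 2.5f, -1.0f, 0.0f }; // toy values
        const std::vector<token_id> last_n = { 1 }; // pretend token 1 was generated recently
        std::mt19937 rng(42);

        const token_id id = sample_top_p_top_k(logits, last_n, 1.30, /*top_k=*/4, 0.95, 0.80, rng);
        std::printf("sampled token: %d\n", id);
        return 0;
    }

Factoring the top-k filter into its own helper is what lets llama_sample_top_p_top_k() replace the previous full std::sort over the whole vocabulary with a partial sort over just the K best candidates.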