From fb64edddb77fe10c49778ad0ce75ea9910e3b46d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 15 Jan 2023 15:53:08 +0200
Subject: [PATCH] gpt : fix sampling to use the temperature (close #16)

---
 examples/gpt-2/main.cpp | 14 ++++----
 examples/gpt-j/main.cpp | 12 +++----
 examples/utils.cpp      | 77 +++++++++++++++++++----------------------
 3 files changed, 49 insertions(+), 54 deletions(-)

diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
index 333d93b..134a930 100644
--- a/examples/gpt-2/main.cpp
+++ b/examples/gpt-2/main.cpp
@@ -347,7 +347,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 // - n_threads: number of threads to use
 // - n_past:    the context size so far
 // - embd_inp:  the embeddings of the tokens in the context
-// - embd_w:    the predicted probabilities of the next token
+// - embd_w:    the predicted logits for the next token
 //
 bool gpt2_eval(
         const gpt2_model & model,
@@ -627,7 +627,7 @@ bool gpt2_eval(
     inpL = ggml_mul_mat(ctx0, model.wte, inpL);
 
     // logits -> probs
-    inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -641,7 +641,7 @@ bool gpt2_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
+    // return result just for the last token
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
 
@@ -698,7 +698,7 @@ int main(int argc, char ** argv) {
     int64_t t_sample_us  = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> embd_w;
+    std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@@ -714,14 +714,14 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -745,7 +745,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp
index c55157d..63248d7 100644
--- a/examples/gpt-j/main.cpp
+++ b/examples/gpt-j/main.cpp
@@ -355,7 +355,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
 // - n_threads: number of threads to use
 // - n_past:    the context size so far
 // - embd_inp:  the embeddings of the tokens in the context
-// - embd_w:    the predicted probabilities of the next token
+// - embd_w:    the predicted logits for the next token
 //
 // The GPT-J model requires about 16MB of memory per input token.
 //
@@ -559,7 +559,7 @@ bool gptj_eval(
     }
 
     // logits -> probs
-    inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
     int64_t t_sample_us  = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> embd_w;
+    std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@@ -644,14 +644,14 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gptj_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+            if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -675,7 +675,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
diff --git a/examples/utils.cpp b/examples/utils.cpp
index 1cd59d9..30057b7 100644
--- a/examples/utils.cpp
+++ b/examples/utils.cpp
@@ -261,8 +261,11 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     std::vector<std::pair<double, gpt_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
-    for (int i = 0; i < n_logits; i++) {
-        logits_id.push_back(std::make_pair(logits[i], i));
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        }
     }
 
     // find the top K tokens
@@ -275,59 +278,51 @@ gpt_vocab::id gpt_sample_top_k_top_p(
 
     logits_id.resize(top_k);
 
-    // normalize
-    {
-        double sum = 0.0f;
-        for (int i = 0; i < (int)logits_id.size(); i++) {
-            sum += logits_id[i].first;
-        }
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
 
-        sum = 1.0/sum;
-        for (int i = 0; i < (int)logits_id.size(); i++) {
-            logits_id[i].first *= sum;
-        }
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
     }
 
     if (top_p < 1.0f) {
-        {
-            double cumsum = 0.0f;
-            for (int i = 0; i < top_k; i++) {
-                cumsum += logits_id[i].first;
-                if (cumsum >= top_p) {
-                    logits_id.resize(i+1);
-                    break;
-                }
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
             }
         }
 
-        // normalize again
-        {
-            double sum = 0.0f;
-            for (int i = 0; i < (int)logits_id.size(); i++) {
-                sum += logits_id[i].first;
-            }
-
-            sum = 1.0/sum;
-            for (int i = 0; i < (int)logits_id.size(); i++) {
-                logits_id[i].first *= sum;
-            }
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
         }
     }
 
     //printf("\n");
-    //for (int i = 0; i < (int)logits_id.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
+    //for (int i = 0; i < (int) probs.size(); i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
     //}
     //exit(0);
 
-    // sample from the obtained distribution
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    for (int i = 0; i < (int) logits_id.size(); i++) {
-        probs.push_back(logits_id[i].first);
-    }
-
     std::discrete_distribution<> dist(probs.begin(), probs.end());
     int idx = dist(rng);
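
Note (not part of the patch): the change takes raw logits out of the graph and applies the temperature by scaling them with 1.0/temp before the top-k cut, computing the softmax over the surviving candidates (subtracting the max for numerical stability), optionally truncating to the top-p nucleus, and only then drawing a token. The standalone C++ sketch below mirrors that flow under those assumptions; the function name, toy logits, and parameter values are illustrative and are not the code in examples/utils.cpp.

// sketch of temperature-scaled top-k/top-p sampling (illustrative, not the patch itself)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

static int sample_top_k_top_p(const std::vector<float> & logits,
                              int top_k, double top_p, double temp,
                              std::mt19937 & rng) {
    // temperature scaling: divide every logit by temp (temp must be > 0)
    std::vector<std::pair<double, int>> scaled;
    scaled.reserve(logits.size());
    const double scale = 1.0/temp;
    for (int i = 0; i < (int) logits.size(); ++i) {
        scaled.push_back(std::make_pair(logits[i]*scale, i));
    }

    // keep only the top_k highest-scoring tokens
    top_k = std::min(top_k, (int) scaled.size());
    std::partial_sort(scaled.begin(), scaled.begin() + top_k, scaled.end(),
            [](const std::pair<double, int> & a, const std::pair<double, int> & b) {
        return a.first > b.first;
    });
    scaled.resize(top_k);

    // softmax over the surviving logits, subtracting the max for stability
    double maxl = -INFINITY;
    for (const auto & kv : scaled) {
        maxl = std::max(maxl, kv.first);
    }

    std::vector<double> probs;
    probs.reserve(scaled.size());
    double sum = 0.0;
    for (const auto & kv : scaled) {
        const double p = std::exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }
    for (auto & p : probs) {
        p /= sum;
    }

    // top-p (nucleus) truncation: keep the smallest prefix whose mass reaches top_p
    if (top_p < 1.0) {
        double cumsum = 0.0;
        for (int i = 0; i < (int) probs.size(); ++i) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                probs.resize(i + 1);
                scaled.resize(i + 1);
                break;
            }
        }
        for (auto & p : probs) {
            p /= cumsum;
        }
    }

    // draw a candidate index and map it back to the original token id
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return scaled[dist(rng)].second;
}

int main() {
    std::mt19937 rng(1234);
    const std::vector<float> logits = { 2.0f, 1.0f, 0.5f, -1.0f }; // toy values
    const int id = sample_top_k_top_p(logits, /*top_k=*/3, /*top_p=*/0.9, /*temp=*/0.8, rng);
    printf("sampled token id: %d\n", id);
    return 0;
}

Because the scaling happens on logits rather than on already-normalized probabilities, temp < 1 sharpens the distribution toward the highest-scoring token and temp > 1 flattens it, which is the behavior the temperature parameter is meant to control.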