gpt : fix sampling to use the temperature (close #16)

4bit
Georgi Gerganov 2 years ago
parent c40a5b51a0
commit fb64edddb7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -347,7 +347,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use // - n_threads: number of threads to use
// - n_past: the context size so far // - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context // - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token // - embd_w: the predicted logits for the next token
// //
bool gpt2_eval( bool gpt2_eval(
const gpt2_model & model, const gpt2_model & model,
@ -627,7 +627,7 @@ bool gpt2_eval(
inpL = ggml_mul_mat(ctx0, model.wte, inpL); inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs // logits -> probs
inpL = ggml_soft_max(ctx0, inpL); //inpL = ggml_soft_max(ctx0, inpL);
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
@ -641,7 +641,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N); //embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token // return result just for the last token
embd_w.resize(n_vocab); embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@ -698,7 +698,7 @@ int main(int argc, char ** argv) {
int64_t t_sample_us = 0; int64_t t_sample_us = 0;
int64_t t_predict_us = 0; int64_t t_predict_us = 0;
std::vector<float> embd_w; std::vector<float> logits;
// tokenize the prompt // tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt); std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@ -714,14 +714,14 @@ int main(int argc, char ** argv) {
// determine the required inference memory per token: // determine the required inference memory per token:
size_t mem_per_token = 0; size_t mem_per_token = 0;
gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token); gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) { if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
printf("Failed to predict\n"); printf("Failed to predict\n");
return 1; return 1;
} }
@ -745,7 +745,7 @@ int main(int argc, char ** argv) {
{ {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng); id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
t_sample_us += ggml_time_us() - t_start_sample_us; t_sample_us += ggml_time_us() - t_start_sample_us;
} }

@ -355,7 +355,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
// - n_threads: number of threads to use // - n_threads: number of threads to use
// - n_past: the context size so far // - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context // - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token // - embd_w: the predicted logits for the next token
// //
// The GPT-J model requires about 16MB of memory per input token. // The GPT-J model requires about 16MB of memory per input token.
// //
@ -559,7 +559,7 @@ bool gptj_eval(
} }
// logits -> probs // logits -> probs
inpL = ggml_soft_max(ctx0, inpL); //inpL = ggml_soft_max(ctx0, inpL);
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
int64_t t_sample_us = 0; int64_t t_sample_us = 0;
int64_t t_predict_us = 0; int64_t t_predict_us = 0;
std::vector<float> embd_w; std::vector<float> logits;
// tokenize the prompt // tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt); std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@ -644,14 +644,14 @@ int main(int argc, char ** argv) {
// determine the required inference memory per token: // determine the required inference memory per token:
size_t mem_per_token = 0; size_t mem_per_token = 0;
gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token); gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
if (!gptj_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) { if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
printf("Failed to predict\n"); printf("Failed to predict\n");
return 1; return 1;
} }
@ -675,7 +675,7 @@ int main(int argc, char ** argv) {
{ {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng); id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
t_sample_us += ggml_time_us() - t_start_sample_us; t_sample_us += ggml_time_us() - t_start_sample_us;
} }

@ -261,8 +261,11 @@ gpt_vocab::id gpt_sample_top_k_top_p(
std::vector<std::pair<double, gpt_vocab::id>> logits_id; std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits); logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) { {
logits_id.push_back(std::make_pair(logits[i], i)); const double scale = 1.0/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
} }
// find the top K tokens // find the top K tokens
@ -275,59 +278,51 @@ gpt_vocab::id gpt_sample_top_k_top_p(
logits_id.resize(top_k); logits_id.resize(top_k);
// normalize double maxl = -INFINITY;
{ for (const auto & kv : logits_id) {
double sum = 0.0f; maxl = std::max(maxl, kv.first);
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
} }
sum = 1.0/sum; // compute probs for the top K tokens
for (int i = 0; i < (int)logits_id.size(); i++) { std::vector<double> probs;
logits_id[i].first *= sum; probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
} }
// normalize the probs
for (auto & p : probs) {
p /= sum;
} }
if (top_p < 1.0f) { if (top_p < 1.0f) {
{
double cumsum = 0.0f; double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) { for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first; cumsum += probs[i];
if (cumsum >= top_p) { if (cumsum >= top_p) {
logits_id.resize(i+1); top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break; break;
} }
} }
}
// normalize again cumsum = 1.0/cumsum;
{ for (int i = 0; i < (int) probs.size(); i++) {
double sum = 0.0f; probs[i] *= cumsum;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
} }
} }
//printf("\n"); //printf("\n");
//for (int i = 0; i < (int)logits_id.size(); i++) { //for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first); // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//} //}
//exit(0); //exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end()); std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng); int idx = dist(rng);

Loading…
Cancel
Save