|
|
@ -347,7 +347,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|
|
|
// - n_threads: number of threads to use
|
|
|
|
// - n_threads: number of threads to use
|
|
|
|
// - n_past: the context size so far
|
|
|
|
// - n_past: the context size so far
|
|
|
|
// - embd_inp: the embeddings of the tokens in the context
|
|
|
|
// - embd_inp: the embeddings of the tokens in the context
|
|
|
|
// - embd_w: the predicted probabilities of the next token
|
|
|
|
// - embd_w: the predicted logits for the next token
|
|
|
|
//
|
|
|
|
//
|
|
|
|
bool gpt2_eval(
|
|
|
|
bool gpt2_eval(
|
|
|
|
const gpt2_model & model,
|
|
|
|
const gpt2_model & model,
|
|
|
@ -627,7 +627,7 @@ bool gpt2_eval(
|
|
|
|
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
|
|
|
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
|
|
|
|
|
|
|
|
|
|
|
// logits -> probs
|
|
|
|
// logits -> probs
|
|
|
|
inpL = ggml_soft_max(ctx0, inpL);
|
|
|
|
//inpL = ggml_soft_max(ctx0, inpL);
|
|
|
|
|
|
|
|
|
|
|
|
// run the computation
|
|
|
|
// run the computation
|
|
|
|
ggml_build_forward_expand(&gf, inpL);
|
|
|
|
ggml_build_forward_expand(&gf, inpL);
|
|
|
@ -641,7 +641,7 @@ bool gpt2_eval(
|
|
|
|
//embd_w.resize(n_vocab*N);
|
|
|
|
//embd_w.resize(n_vocab*N);
|
|
|
|
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
|
|
|
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
|
|
|
|
|
|
|
|
|
|
|
// return result for just the last token
|
|
|
|
// return result just for the last token
|
|
|
|
embd_w.resize(n_vocab);
|
|
|
|
embd_w.resize(n_vocab);
|
|
|
|
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
|
|
|
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
|
|
|
|
|
|
|
|
|
|
@ -698,7 +698,7 @@ int main(int argc, char ** argv) {
|
|
|
|
int64_t t_sample_us = 0;
|
|
|
|
int64_t t_sample_us = 0;
|
|
|
|
int64_t t_predict_us = 0;
|
|
|
|
int64_t t_predict_us = 0;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<float> embd_w;
|
|
|
|
std::vector<float> logits;
|
|
|
|
|
|
|
|
|
|
|
|
// tokenize the prompt
|
|
|
|
// tokenize the prompt
|
|
|
|
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
|
|
|
|
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
|
|
|
@ -714,14 +714,14 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
|
|
|
|
|
|
// determine the required inference memory per token:
|
|
|
|
// determine the required inference memory per token:
|
|
|
|
size_t mem_per_token = 0;
|
|
|
|
size_t mem_per_token = 0;
|
|
|
|
gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
|
|
|
|
gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
|
|
|
|
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
|
|
|
|
// predict
|
|
|
|
// predict
|
|
|
|
if (embd.size() > 0) {
|
|
|
|
if (embd.size() > 0) {
|
|
|
|
const int64_t t_start_us = ggml_time_us();
|
|
|
|
const int64_t t_start_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
|
|
|
|
if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
|
|
|
|
printf("Failed to predict\n");
|
|
|
|
printf("Failed to predict\n");
|
|
|
|
return 1;
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -745,7 +745,7 @@ int main(int argc, char ** argv) {
|
|
|
|
{
|
|
|
|
{
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
|
|
id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
|
|
|
|
id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
|
|
|
|
|
|
|
|
|
|
|
|
t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
|
|
}
|
|
|
|
}
|
|
|
|