From 307dba3dd2e4be3e99d61907d3be3a27d2ca492f Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 10:00:01 +0800 Subject: [PATCH 1/5] first try to integrate sentencepiece --- .gitignore | 1 + Makefile | 6 +++--- main.cpp | 20 ++++++++++++++++---- utils.cpp | 40 +++++++++++++++++++--------------------- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 5eb1ff1..25111eb 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ models/* arm_neon.h compile_commands.json +*.dSYM/ diff --git a/Makefile b/Makefile index 8388c29..35b627b 100644 --- a/Makefile +++ b/Makefile @@ -30,9 +30,9 @@ endif # Compile flags # -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -LDFLAGS = +CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -g -I/opt/homebrew/include +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -g -I/opt/homebrew/include +LDFLAGS = -L/opt/homebrew/lib -lsentencepiece # OS specific # TODO: support Windows diff --git a/main.cpp b/main.cpp index d1defe2..31872c3 100644 --- a/main.cpp +++ b/main.cpp @@ -10,6 +10,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include #include @@ -82,7 +88,7 @@ struct llama_model { }; // load the model's weights from a file -bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) { +bool llama_model_load(const std::string & fname, llama_model & model, sentencepiece::SentencePieceProcessor & sp, gpt_vocab & vocab, int n_ctx) { printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); auto fin = std::ifstream(fname, std::ios::binary); @@ -144,6 +150,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab return false; } + printf("total pieces: %d", sp.GetPieceSize()); + std::string word; for (int i = 0; i < n_vocab; i++) { uint32_t len; @@ -152,8 +160,9 @@ bool llama_model_load(const std::string 
& fname, llama_model & model, gpt_vocab word.resize(len); fin.read((char *) word.data(), len); - vocab.token_to_id[word] = i; - vocab.id_to_token[i] = word; + std::string wordx = sp.IdToPiece(i); + vocab.token_to_id[wordx] = i; + vocab.id_to_token[i] = wordx; //if (i < 30000) { // printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); @@ -764,6 +773,9 @@ int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; + sentencepiece::SentencePieceProcessor sp; + sp.Load("./models/tokenizer.model"); + if (gpt_params_parse(argc, argv, params) == false) { return 1; } @@ -791,7 +803,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ?? + if (!llama_model_load(params.model, model, sp, vocab, 512)) { // TODO: set context from user input ?? fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } diff --git a/utils.cpp b/utils.cpp index b340bd6..b52f419 100644 --- a/utils.cpp +++ b/utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -281,33 +282,30 @@ std::vector llama_tokenize(const gpt_vocab & vocab, const std::st std::vector res; - if (bos) { - res.push_back(1); // TODO: replace with vocab.bos - } + // if (bos) { + // res.push_back(1); // TODO: replace with vocab.bos + // } - //find the longest token that matches the text - int pos = 0; - while (true) { - int l = 0; - int t = 0; - for (const auto & kv : vocab.id_to_token) { - if (kv.second.size() < l) continue; - if (kv.second.size() > text.size() - pos) continue; - if (text.substr(pos, kv.second.size()) == kv.second) { - l = kv.second.size(); - t = kv.first; - } - } + sentencepiece::SentencePieceProcessor sp; + sp.Load("./models/tokenizer.model"); - if (l == 0) { - break; + std::vector pieces; + return sp.EncodeAsIds(text); +/* + for (const auto & piece : 
pieces) { + printf("piece: %s\n", piece.c_str()); + if (vocab.token_to_id.count(piece) > 0) { + res.push_back(vocab.token_to_id.at(piece)); + } else { + // handle unknown token } + } - res.push_back(t); - pos += l; + for (const auto& id : res) { + printf("%d\n", id); } - return res; + return res;*/ } bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { From 1b87fe1e90c4e519586e28a0096d00e3008747b6 Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 10:00:34 +0800 Subject: [PATCH 2/5] call a standalone function to untokenize output --- main.cpp | 7 ++++--- utils.cpp | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ utils.h | 5 +++++ 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/main.cpp b/main.cpp index 31872c3..d3debfd 100644 --- a/main.cpp +++ b/main.cpp @@ -947,9 +947,10 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { - for (auto id : embd) { - printf("%s", vocab.id_to_token[id].c_str()); - } + untokenize(sp, embd); + // for (auto id : embd) { + // printf("%s", vocab.id_to_token[id].c_str()); + // } // reset color to default if we there is no pending user input if (params.use_color && embd_inp.size() <= input_consumed) { printf(ANSI_COLOR_RESET); diff --git a/utils.cpp b/utils.cpp index b52f419..8a75039 100644 --- a/utils.cpp +++ b/utils.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -540,3 +541,65 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t return (n/k)*row_size; } + +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd) +{ + // Convert the IDs in embd to tokens using SentencePiece + // std::vector pieces; + // for (const auto& id : embd) { + // //std::string s = sp.DecodeIds(id); + + // //s = std::regex_replace(s, std::regex("▁"), " "); + + // // if (s.find("<0x") == 0 && s[s.length() - 1] == '>') + // // { + // // s = sp.IdToPiece(id); + // // } + // //printf("%s", s.c_str()); + 
+ // pieces.push_back(id); + // // if(s.length() > 1) + // // tokens.push_back(" "); + // } + // // Insert spaces between tokens + // // std::string text; + // // for (const auto& token : tokens) { + // // // Add a space before the token if it is not the first token and it doesn't start with a special character + // // if (!text.empty() && !(token[0] == '\0x25' && token[1] == '\0x81') && token[0] != ' ') { + // // text += ' '; + // // } + // // text += sp.DecodePieces(tokens); + // // } + // //sp.DecodeIds(embd); + // std::string text = + // sp.DecodeIds(pieces); + + // printf("%s", text.c_str()); + for (auto id : embd) { + std::string s = sp.IdToPiece(id); //vocab.id_to_token[id]; + + if (s.find("<0x") == 0 && s[s.length() - 1] == '>') + { + // Extract the hexadecimal value from the token + std::string hex_value = s.substr(s.find("0x")); + + // Convert the hexadecimal value to binary and print it + int decimal_value; + std::stringstream(hex_value) >> std::hex >> decimal_value; + std::bitset<8> binary_value(decimal_value); + + char* bytes = reinterpret_cast(&decimal_value); + printf("%s", bytes); + } + else if(s.find("▁") == 0) + { + s = std::regex_replace(s, std::regex("▁"), " "); + //s.replace(0, 2, 1, ' '); + printf("%s", s.c_str()); + } + else + { + printf("%s", s.c_str()); + } + } +} \ No newline at end of file diff --git a/utils.h b/utils.h index 4f98011..3904133 100644 --- a/utils.h +++ b/utils.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include // // CLI argument parsing @@ -102,3 +104,6 @@ void sample_top_k(std::vector> & logits_id, int size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist); size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist); + +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd); + From 86e967c54b126312090ac540585e9903b4d28efa Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 13:06:01 +0800 Subject: [PATCH 3/5] 
buffering output for UTF-8 encoded token --- utils.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/utils.cpp b/utils.cpp index 8a75039..17cd917 100644 --- a/utils.cpp +++ b/utils.cpp @@ -544,6 +544,9 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd) { + // std::string output = sp.DecodeIds(embd); + // printf("%s", output.c_str()); + // return; // Convert the IDs in embd to tokens using SentencePiece // std::vector pieces; // for (const auto& id : embd) { @@ -575,6 +578,7 @@ void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector binary_value(decimal_value); char* bytes = reinterpret_cast(&decimal_value); - printf("%s", bytes); + buff = buff + std::string(bytes); + //printf("bufferring %s, total buffer: %s\n", s.c_str(), buff.c_str()); } else if(s.find("▁") == 0) { + if(!buff.empty()) + { + printf("%s", buff.c_str()); + buff = ""; + } s = std::regex_replace(s, std::regex("▁"), " "); //s.replace(0, 2, 1, ' '); printf("%s", s.c_str()); } else { + if(!buff.empty()) + { + printf("%s", buff.c_str()); + buff = ""; + } printf("%s", s.c_str()); } } From 15f06f6b4f2074da448851adfc1d887ea7cb76f0 Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 16:27:11 +0800 Subject: [PATCH 4/5] buffering utf-8 output to make it complete for splitting output. 
--- main.cpp | 3 ++- utils.cpp | 14 ++++++++++---- utils.h | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/main.cpp b/main.cpp index d3debfd..5ad5dc0 100644 --- a/main.cpp +++ b/main.cpp @@ -886,6 +886,7 @@ int main(int argc, char ** argv) { printf(ANSI_COLOR_YELLOW); } + std::vector buffids = {}; while (remaining_tokens > 0) { // predict if (embd.size() > 0) { @@ -947,7 +948,7 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { - untokenize(sp, embd); + untokenize(sp, buffids, embd); // for (auto id : embd) { // printf("%s", vocab.id_to_token[id].c_str()); // } diff --git a/utils.cpp b/utils.cpp index 17cd917..87ee980 100644 --- a/utils.cpp +++ b/utils.cpp @@ -542,7 +542,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t return (n/k)*row_size; } -void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd) +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & buffids, std::vector & embd) { // std::string output = sp.DecodeIds(embd); // printf("%s", output.c_str()); @@ -578,12 +578,14 @@ void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector') { + buffids.push_back(id); // Extract the hexadecimal value from the token std::string hex_value = s.substr(s.find("0x")); @@ -600,7 +602,9 @@ void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector> & logits_id, int size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist); size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist); -void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd); +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & buffids, std::vector & embd); From ed10def70e5cb9c5d23f3d21fc239047ba919e29 Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 17:02:06 +0800 Subject: [PATCH 5/5] clean code --- main.cpp | 5 ++- utils.cpp | 103 
+++++++++++++++--------------------------------------- 2 files changed, 30 insertions(+), 78 deletions(-) diff --git a/main.cpp b/main.cpp index 5ad5dc0..820e40a 100644 --- a/main.cpp +++ b/main.cpp @@ -886,6 +886,7 @@ int main(int argc, char ** argv) { printf(ANSI_COLOR_YELLOW); } + // buffering UTF-8 tokens like <0xE6>,<0xAC><0xA2> spanning across multiple output to make it complete. std::vector buffids = {}; while (remaining_tokens > 0) { // predict @@ -949,9 +950,7 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { untokenize(sp, buffids, embd); - // for (auto id : embd) { - // printf("%s", vocab.id_to_token[id].c_str()); - // } + // reset color to default if we there is no pending user input if (params.use_color && embd_inp.size() <= input_consumed) { printf(ANSI_COLOR_RESET); diff --git a/utils.cpp b/utils.cpp index 87ee980..36cb95e 100644 --- a/utils.cpp +++ b/utils.cpp @@ -542,85 +542,38 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t return (n/k)*row_size; } -void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & buffids, std::vector & embd) +void untokenize(sentencepiece::SentencePieceProcessor &sp, std::vector &buffids, std::vector &embd) { - // std::string output = sp.DecodeIds(embd); - // printf("%s", output.c_str()); - // return; - // Convert the IDs in embd to tokens using SentencePiece - // std::vector pieces; - // for (const auto& id : embd) { - // //std::string s = sp.DecodeIds(id); - - // //s = std::regex_replace(s, std::regex("▁"), " "); - - // // if (s.find("<0x") == 0 && s[s.length() - 1] == '>') - // // { - // // s = sp.IdToPiece(id); - // // } - // //printf("%s", s.c_str()); - - // pieces.push_back(id); - // // if(s.length() > 1) - // // tokens.push_back(" "); - // } - // // Insert spaces between tokens - // // std::string text; - // // for (const auto& token : tokens) { - // // // Add a space before the token if it is not the first token and it doesn't start with 
a special character - // // if (!text.empty() && !(token[0] == '\0x25' && token[1] == '\0x81') && token[0] != ' ') { - // // text += ' '; - // // } - // // text += sp.DecodePieces(tokens); - // // } - // //sp.DecodeIds(embd); - // std::string text = - // sp.DecodeIds(pieces); - - // printf("%s", text.c_str()); - - std::string buff; - for (auto id : embd) { - std::string s = sp.IdToPiece(id); //vocab.id_to_token[id]; - - if (s.find("<0x") == 0 && s[s.length() - 1] == '>') - { - buffids.push_back(id); - // Extract the hexadecimal value from the token - std::string hex_value = s.substr(s.find("0x")); - - // Convert the hexadecimal value to binary and print it - int decimal_value; - std::stringstream(hex_value) >> std::hex >> decimal_value; - std::bitset<8> binary_value(decimal_value); - - char* bytes = reinterpret_cast(&decimal_value); - buff = buff + std::string(bytes); - //printf("bufferring %s, total buffer: %s\n", s.c_str(), buff.c_str()); - } - else if(s.find("▁") == 0) + for (auto id : embd) + { + std::string s = sp.IdToPiece(id); // vocab.id_to_token[id]; + + if (s.find("<0x") == 0 && s[s.length() - 1] == '>') + { + buffids.push_back(id); + std::string txt = sp.DecodeIds(buffids); + // printf("bufferring %s, total buffer: %s\n", s.c_str(), txt.c_str()); + } + else if (s.find("▁") == 0) + { + if (!buffids.empty()) { - if(!buff.empty()) - { - std::string txt = sp.DecodeIds(buffids); - printf("%s", txt.c_str()); - buffids.clear(); - buff = ""; - } - s = std::regex_replace(s, std::regex("▁"), " "); - //s.replace(0, 2, 1, ' '); - printf("%s", s.c_str()); + std::string txt = sp.DecodeIds(buffids); + printf("%s", txt.c_str()); + buffids.clear(); } - else + s = std::regex_replace(s, std::regex("▁"), " "); + printf("%s", s.c_str()); + } + else + { + if (!buffids.empty()) { - if(!buff.empty()) - { - std::string txt = sp.DecodeIds(buffids); - printf("%s", txt.c_str()); - buffids.clear(); - buff = ""; - } - printf("%s", s.c_str()); + std::string txt = 
sp.DecodeIds(buffids); + printf("%s", txt.c_str()); + buffids.clear(); } + printf("%s", s.c_str()); } + } } \ No newline at end of file