From 1b87fe1e90c4e519586e28a0096d00e3008747b6 Mon Sep 17 00:00:00 2001 From: wizard Date: Mon, 13 Mar 2023 10:00:34 +0800 Subject: [PATCH] call a standalone function to untokenize output --- main.cpp | 7 ++++--- utils.cpp | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ utils.h | 5 +++++ 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/main.cpp b/main.cpp index 31872c3..d3debfd 100644 --- a/main.cpp +++ b/main.cpp @@ -947,9 +947,10 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { - for (auto id : embd) { - printf("%s", vocab.id_to_token[id].c_str()); - } + untokenize(sp, embd); + // for (auto id : embd) { + // printf("%s", vocab.id_to_token[id].c_str()); + // } // reset color to default if we there is no pending user input if (params.use_color && embd_inp.size() <= input_consumed) { printf(ANSI_COLOR_RESET); diff --git a/utils.cpp b/utils.cpp index b52f419..8a75039 100644 --- a/utils.cpp +++ b/utils.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -540,3 +541,65 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t return (n/k)*row_size; } + +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd) +{ + // Convert the IDs in embd to tokens using SentencePiece + // std::vector pieces; + // for (const auto& id : embd) { + // //std::string s = sp.DecodeIds(id); + + // //s = std::regex_replace(s, std::regex("▁"), " "); + + // // if (s.find("<0x") == 0 && s[s.length() - 1] == '>') + // // { + // // s = sp.IdToPiece(id); + // // } + // //printf("%s", s.c_str()); + + // pieces.push_back(id); + // // if(s.length() > 1) + // // tokens.push_back(" "); + // } + // // Insert spaces between tokens + // // std::string text; + // // for (const auto& token : tokens) { + // // // Add a space before the token if it is not the first token and it doesn't start with a special character + // // if (!text.empty() && !(token[0] == '\0x25' && token[1] == '\0x81') && token[0] != ' ') { + // // text += ' '; + // // } + // // text += sp.DecodePieces(tokens); + // // } + // //sp.DecodeIds(embd); + // std::string text = + // sp.DecodeIds(pieces); + + // printf("%s", text.c_str()); + for (auto id : embd) { + std::string s = sp.IdToPiece(id); //vocab.id_to_token[id]; + + if (s.find("<0x") == 0 && s[s.length() - 1] == '>') + { + // Extract the hexadecimal value from the token + std::string hex_value = s.substr(s.find("0x")); + + // Convert the hexadecimal value to binary and print it + int decimal_value; + std::stringstream(hex_value) >> std::hex >> decimal_value; + std::bitset<8> binary_value(decimal_value); + + char* bytes = reinterpret_cast(&decimal_value); + printf("%s", bytes); + } + else if(s.find("▁") == 0) + { + s = std::regex_replace(s, std::regex("▁"), " "); + //s.replace(0, 2, 1, ' '); + printf("%s", s.c_str()); + } + else + { + printf("%s", s.c_str()); + } + } +} \ No newline at end of file diff --git a/utils.h b/utils.h index 4f98011..3904133 100644 --- a/utils.h +++ b/utils.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include // // CLI argument parsing @@ -102,3 +104,6 @@ void sample_top_k(std::vector> & logits_id, int size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist); size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist); + +void untokenize(sentencepiece::SentencePieceProcessor & sp, std::vector & embd); +