From 78651d5792616e7df5a68e0397d131da5b138ef6 Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 14:24:32 -0500 Subject: [PATCH] Update utils.cpp --- utils.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/utils.cpp b/utils.cpp index abb3475..59fd05a 100644 --- a/utils.cpp +++ b/utils.cpp @@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { params.temp = std::stof(argv[++i]); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(argv[++i]); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(argv[++i]); } else if (arg == "-b" || arg == "--batch_size") { params.n_batch = std::stoi(argv[++i]); } else if (arg == "-m" || arg == "--model") { @@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); @@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p( gpt_vocab::id llama_sample_top_p( const gpt_vocab & vocab, const float * logits, + std::vector & last_n_tokens, + double repeat_penalty, double top_p, double temp, std::mt19937 & rng) { @@ -383,7 +391,11 @@ gpt_vocab::id llama_sample_top_p( { const double scale = 1.0/temp; for (int i = 0; i < n_logits; ++i) { - logits_id.push_back(std::make_pair(logits[i]*scale, i)); + if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { + logits_id.push_back(std::make_pair(logits[i]*(1/repeat_penalty), i)); + } else { + logits_id.push_back(std::make_pair(logits[i]*scale, i)); + } } }