diff --git a/utils.cpp b/utils.cpp index eceab45..b6c9493 100644 --- a/utils.cpp +++ b/utils.cpp @@ -391,8 +391,16 @@ gpt_vocab::id llama_sample_top_p( { const double scale = 1.0/temp; for (int i = 0; i < n_logits; ++i) { + // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { - logits_id.push_back(std::make_pair(logits[i]*scale*(1/repeat_penalty), i)); + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if(logits[i] < 0.0) { + logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); + } + } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); }