@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params . top_p = std : : stof ( argv [ + + i ] ) ;
params . top_p = std : : stof ( argv [ + + i ] ) ;
} else if ( arg = = " --temp " ) {
} else if ( arg = = " --temp " ) {
params . temp = std : : stof ( argv [ + + i ] ) ;
params . temp = std : : stof ( argv [ + + i ] ) ;
} else if ( arg = = " --repeat_last_n " ) {
params . repeat_last_n = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " --repeat_penalty " ) {
params . repeat_penalty = std : : stof ( argv [ + + i ] ) ;
} else if ( arg = = " -b " | | arg = = " --batch_size " ) {
} else if ( arg = = " -b " | | arg = = " --batch_size " ) {
params . n_batch = std : : stoi ( argv [ + + i ] ) ;
params . n_batch = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -m " | | arg = = " --model " ) {
} else if ( arg = = " -m " | | arg = = " --model " ) {
@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf ( stderr , " -n N, --n_predict N number of tokens to predict (default: %d) \n " , params . n_predict ) ;
fprintf ( stderr , " -n N, --n_predict N number of tokens to predict (default: %d) \n " , params . n_predict ) ;
fprintf ( stderr , " --top_k N top-k sampling (default: %d) \n " , params . top_k ) ;
fprintf ( stderr , " --top_k N top-k sampling (default: %d) \n " , params . top_k ) ;
fprintf ( stderr , " --top_p N top-p sampling (default: %.1f) \n " , params . top_p ) ;
fprintf ( stderr , " --top_p N top-p sampling (default: %.1f) \n " , params . top_p ) ;
fprintf ( stderr , " --repeat_last_n N last n tokens to consider for penalize (default: %d) \n " , params . repeat_last_n ) ;
fprintf ( stderr , " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f) \n " , params . repeat_penalty ) ;
fprintf ( stderr , " --temp N temperature (default: %.1f) \n " , params . temp ) ;
fprintf ( stderr , " --temp N temperature (default: %.1f) \n " , params . temp ) ;
fprintf ( stderr , " -b N, --batch_size N batch size for prompt processing (default: %d) \n " , params . n_batch ) ;
fprintf ( stderr , " -b N, --batch_size N batch size for prompt processing (default: %d) \n " , params . n_batch ) ;
fprintf ( stderr , " -m FNAME, --model FNAME \n " ) ;
fprintf ( stderr , " -m FNAME, --model FNAME \n " ) ;
@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
gpt_vocab : : id llama_sample_top_p (
gpt_vocab : : id llama_sample_top_p (
const gpt_vocab & vocab ,
const gpt_vocab & vocab ,
const float * logits ,
const float * logits ,
std : : vector < gpt_vocab : : id > & last_n_tokens ,
double repeat_penalty ,
double top_p ,
double top_p ,
double temp ,
double temp ,
std : : mt19937 & rng ) {
std : : mt19937 & rng ) {
@ -383,7 +391,18 @@ gpt_vocab::id llama_sample_top_p(
{
{
const double scale = 1.0 / temp ;
const double scale = 1.0 / temp ;
for ( int i = 0 ; i < n_logits ; + + i ) {
for ( int i = 0 ; i < n_logits ; + + i ) {
logits_id . push_back ( std : : make_pair ( logits [ i ] * scale , i ) ) ;
// repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if ( std : : find ( last_n_tokens . begin ( ) , last_n_tokens . end ( ) , i ) ! = last_n_tokens . end ( ) ) {
// if score < 0 then the repetition penalty has to be multiplied (rather than divided) to reduce the previous token's probability
if ( logits [ i ] < 0.0 ) {
logits_id . push_back ( std : : make_pair ( logits [ i ] * scale * repeat_penalty , i ) ) ;
} else {
logits_id . push_back ( std : : make_pair ( logits [ i ] * scale / repeat_penalty , i ) ) ;
}
} else {
logits_id . push_back ( std : : make_pair ( logits [ i ] * scale , i ) ) ;
}
}
}
}
}