@@ -792,7 +792,7 @@ int main(int argc, char ** argv) {
         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
     printf("\n");
-    printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
+    printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     printf("\n\n");
 
     std::vector<gpt_vocab::id> embd;
@@ -801,6 +801,10 @@ int main(int argc, char ** argv) {
     size_t mem_per_token = 0;
     llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
+    int last_n_size = params.repeat_last_n;
+    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
+    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
@@ -821,6 +825,7 @@ int main(int argc, char ** argv) {
             // sample next token
             const float top_p = params.top_p;
             const float temp  = params.temp;
+            const float repeat_penalty = params.repeat_penalty;
 
             const int n_vocab = model.hparams.n_vocab;
 
@ -829,7 +834,10 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us ( ) ;
id = llama_sample_top_p ( vocab , logits . data ( ) + ( logits . size ( ) - n_vocab ) , top_p , temp , rng ) ;
id = llama_sample_top_p ( vocab , logits . data ( ) + ( logits . size ( ) - n_vocab ) , last_n_tokens , repeat_penalty , top_p , temp , rng ) ;
last_n_tokens . erase ( last_n_tokens . begin ( ) ) ;
last_n_tokens . push_back ( id ) ;
t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
}
@@ -840,6 +848,8 @@ int main(int argc, char ** argv) {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
                 embd.push_back(embd_inp[k]);
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[k]);
                 if (embd.size() > params.n_batch) {
                     break;
                 }
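
The hunks above only touch the call site: they keep a fixed-size last_n_tokens window (erase the oldest id, push the newest, for both sampled tokens and prompt tokens) and pass it together with repeat_penalty into llama_sample_top_p. The sampler's own body is not part of this diff, so the following is a minimal stand-alone sketch of the conventional (CTRL-style) penalty the new signature implies; apply_repeat_penalty is a hypothetical helper name for illustration, not the function as it appears in the tree.

#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch: rescale the logit of every token id that occurs in the
// recent-history window, before top-p sampling runs.
void apply_repeat_penalty(
        std::vector<float> & logits,            // one logit per vocab id
        const std::vector<int> & last_n_tokens, // sliding window maintained by the caller
        float repeat_penalty) {                 // > 1.0 penalizes; 1.0 is a no-op
    for (int id = 0; id < (int) logits.size(); id++) {
        if (std::find(last_n_tokens.begin(), last_n_tokens.end(), id) == last_n_tokens.end()) {
            continue; // not seen recently - leave its logit untouched
        }
        // Dividing a negative logit by a factor > 1 would *raise* its
        // probability, so negative logits are multiplied instead.
        if (logits[id] < 0.0f) {
            logits[id] *= repeat_penalty;
        } else {
            logits[id] /= repeat_penalty;
        }
    }
}

int main() {
    std::vector<float> logits = { 2.0f, -1.0f, 0.5f };
    std::vector<int> last_n_tokens = { 0, 1 }; // ids 0 and 1 were generated recently
    apply_repeat_penalty(logits, last_n_tokens, 1.3f);
    for (size_t i = 0; i < logits.size(); i++) {
        printf("id %zu: %f\n", i, logits[i]); // ids 0 and 1 end up less likely
    }
    return 0;
}

Handling the two signs asymmetrically is the point of the branch: both positive and negative logits move toward lower probability under the same repeat_penalty factor, which a plain division would not achieve.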