@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
return true ;
return true ;
}
}
gpt_vocab : : id gpt_sample_top_k_top_p (
const gpt_vocab & vocab ,
const float * logits ,
int top_k ,
double top_p ,
double temp ,
std : : mt19937 & rng ) {
int n_logits = vocab . id_to_token . size ( ) ;
std : : vector < std : : pair < double , gpt_vocab : : id > > logits_id ;
logits_id . reserve ( n_logits ) ;
{
const double scale = 1.0 / temp ;
for ( int i = 0 ; i < n_logits ; + + i ) {
logits_id . push_back ( std : : make_pair ( logits [ i ] * scale , i ) ) ;
}
}
void sample_top_k ( std : : vector < std : : pair < double , gpt_vocab : : id > > & logits_id , int top_k ) {
// find the top K tokens
// find the top K tokens
std : : partial_sort (
std : : partial_sort (
logits_id . begin ( ) ,
logits_id . begin ( ) ,
@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
} ) ;
} ) ;
logits_id . resize ( top_k ) ;
logits_id . resize ( top_k ) ;
double maxl = - INFINITY ;
for ( const auto & kv : logits_id ) {
maxl = std : : max ( maxl , kv . first ) ;
}
// compute probs for the top K tokens
std : : vector < double > probs ;
probs . reserve ( logits_id . size ( ) ) ;
double sum = 0.0 ;
for ( const auto & kv : logits_id ) {
double p = exp ( kv . first - maxl ) ;
probs . push_back ( p ) ;
sum + = p ;
}
// normalize the probs
for ( auto & p : probs ) {
p / = sum ;
}
if ( top_p < 1.0f ) {
double cumsum = 0.0f ;
for ( int i = 0 ; i < top_k ; i + + ) {
cumsum + = probs [ i ] ;
if ( cumsum > = top_p ) {
top_k = i + 1 ;
probs . resize ( top_k ) ;
logits_id . resize ( top_k ) ;
break ;
}
}
cumsum = 1.0 / cumsum ;
for ( int i = 0 ; i < ( int ) probs . size ( ) ; i + + ) {
probs [ i ] * = cumsum ;
}
}
}
//printf("\n");
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
std : : discrete_distribution < > dist ( probs . begin ( ) , probs . end ( ) ) ;
int idx = dist ( rng ) ;
return logits_id [ idx ] . second ;
gpt_vocab : : id llama_sample_top_p_top_k (
}
gpt_vocab : : id llama_sample_top_p (
const gpt_vocab & vocab ,
const gpt_vocab & vocab ,
const float * logits ,
const float * logits ,
std : : vector < gpt_vocab : : id > & last_n_tokens ,
std : : vector < gpt_vocab : : id > & last_n_tokens ,
double repeat_penalty ,
double repeat_penalty ,
int top_k ,
double top_p ,
double top_p ,
double temp ,
double temp ,
std : : mt19937 & rng ) {
std : : mt19937 & rng ) {
@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
}
}
}
}
std : : sort (
sample_top_k ( logits_id , top_k ) ;
logits_id . begin ( ) ,
logits_id . end ( ) ,
[ ] ( const std : : pair < double , gpt_vocab : : id > & a , const std : : pair < double , gpt_vocab : : id > & b ) {
return a . first > b . first ;
} ) ;
double maxl = - INFINITY ;
double maxl = - INFINITY ;
for ( const auto & kv : logits_id ) {
for ( const auto & kv : logits_id ) {