@ -1846,7 +1846,9 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best (
const whisper_vocab & vocab ,
const float * probs ) {
const float * probs ,
bool force_timestamp ,
bool is_initial ) {
whisper_token_data result = {
0 , 0 , 0.0f , 0.0f , 0.0f , - 1 , - 1 , 0.0f ,
} ;
@ -1869,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
max_tx = std : : max ( max_tx , probs_id [ i ] . first ) ;
}
for ( int i = vocab . token_beg ; i < n_logits ; i + + ) {
const auto i0 = is_initial ? vocab . token_beg + 101 : vocab . token_beg ;
const auto i1 = is_initial ? vocab . token_beg + 101 : n_logits ;
// the initial timestamp cannot be larger than 100
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if ( is_initial ) {
for ( int i = i0 ; i < n_logits ; + + i ) {
probs_id [ i ] . first = - INFINITY ;
}
}
for ( int i = vocab . token_beg ; i < i1 ; i + + ) {
sum_ts + = probs_id [ i ] . first ;
if ( probs_id [ i ] . first > max_ts ) {
max_ts = probs_id [ i ] . first ;
@ -1879,7 +1892,7 @@ static whisper_token_data whisper_sample_best(
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
if ( sum_ts > max_tx ) {
if ( sum_ts > max_tx | | force_timestamp ) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for ( int i = 0 ; i < vocab . token_beg ; i + + ) {
probs_id [ i ] . first = - INFINITY ;
@ -1921,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
return result ;
}
// Samples only from the timestamp tokens: returns the id of the most probable
// token with an id in the open range (vocab.token_beg, n_logits), where n_logits
// is vocab.id_to_token.size(). The token vocab.token_beg itself is not a candidate.
//
//   vocab - vocabulary; only token_beg and id_to_token.size() are read
//   probs - per-token probabilities indexed by token id; must hold at least
//           vocab.id_to_token.size() entries
//
// If there is no token after vocab.token_beg there is nothing to sample from and
// vocab.token_beg is returned as a fallback.
static whisper_vocab::id whisper_sample_timestamp(
        const whisper_vocab & vocab,
        const float         * probs) {
    const int n_logits = vocab.id_to_token.size();
    const int i_first  = vocab.token_beg + 1;

    // guard: the previous implementation partially sorted the top 10 candidates
    // without checking that 10 candidates exist, which is undefined behaviour
    // for a short (or empty) candidate list
    if (i_first >= n_logits) {
        return vocab.token_beg;
    }

    // only the single best candidate is returned, so a linear scan is enough -
    // no temporary vector and no top-K partial sort are needed
    const float * best = std::max_element(probs + i_first, probs + n_logits);

    return static_cast<whisper_vocab::id>(best - probs);
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std : : string to_timestamp ( int64_t t , bool comma = false ) {
@ -2284,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
// Context-level sampling entry point: calls the vocab/probs overload of
// whisper_sample_best on the last ctx->vocab.n_vocab entries of ctx->probs,
// adds the elapsed time to ctx->t_sample_us and returns the sampled token data.
struct whisper_token_data whisper_sample_best ( struct whisper_context * ctx ) {
const int64_t t_start_sample_us = ggml_time_us ( ) ;
// TODO: simplify
// NOTE(review): the two declarations of `res` below look like the pre-change and
// post-change versions of the same statement (diff residue) - presumably only the
// second one is meant to remain; confirm. In the second call the trailing
// arguments are force_timestamp = false and is_initial = false.
auto res = whisper_sample_best ( ctx - > vocab , ctx - > probs . data ( ) + ( ctx - > probs . size ( ) - ctx - > vocab . n_vocab ) ) ;
const auto res = whisper_sample_best ( ctx - > vocab , ctx - > probs . data ( ) + ( ctx - > probs . size ( ) - ctx - > vocab . n_vocab ) , false , false ) ;
// accumulate the time spent in sampling
ctx - > t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
return res ;
}
whisper_token whisper_sample_timestamp ( struct whisper_context * ctx ) {
struct whisper_token_data whisper_sample_timestamp ( struct whisper_context * ctx , bool is_initial ) {
const int64_t t_start_sample_us = ggml_time_us ( ) ;
// TODO: simplify
auto res = whisper_sample_timestamp ( ctx - > vocab , ctx - > probs . data ( ) + ( ctx - > probs . size ( ) - ctx - > vocab . n_vocab ) ) ;
const auto res = whisper_sample_best ( ctx - > vocab , ctx - > probs . data ( ) + ( ctx - > probs . size ( ) - ctx - > vocab . n_vocab ) , true , is_initial ) ;
ctx - > t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
@ -2694,7 +2672,6 @@ int whisper_full(
prompt . insert ( prompt . end ( ) , prompt_init . begin ( ) , prompt_init . end ( ) ) ;
bool done = false ;
int seek_delta = 100 * WHISPER_CHUNK_SIZE ;
// print the prompt
@ -2708,7 +2685,9 @@ int whisper_full(
int result_len = 0 ;
tokens_cur . clear ( ) ;
for ( int i = 0 ; i < whisper_n_text_ctx ( ctx ) / 2 - 4 ; + + i ) {
bool failed = false ;
for ( int i = 0 , n_max = whisper_n_text_ctx ( ctx ) / 2 - 4 ; i < n_max ; + + i ) {
if ( whisper_decode ( ctx , prompt . data ( ) , prompt . size ( ) , n_past , params . n_threads ) ! = 0 ) {
fprintf ( stderr , " %s: failed to decode \n " , __func__ ) ;
return 8 ;
@ -2725,15 +2704,19 @@ int whisper_full(
// feel free to experiment!
//
{
auto token = whisper_sample_best ( ctx ) ;
if ( i = = 0 ) {
token . tid = whisper_token_beg ( ctx ) ;
}
const auto token = ( i = = 0 ) ? whisper_sample_timestamp ( ctx , true ) : whisper_sample_best ( ctx ) ;
// timestamp token - update sliding window
if ( token . id > whisper_token_beg ( ctx ) ) {
seek_delta = 2 * ( token . id - whisper_token_beg ( ctx ) ) ;
const int seek_delta_new = 2 * ( token . id - whisper_token_beg ( ctx ) ) ;
// do not allow to go back in time
if ( seek_delta ! = 100 * WHISPER_CHUNK_SIZE & &
seek_delta > seek_delta_new & & result_len < i ) {
break ;
}
seek_delta = seek_delta_new ;
result_len = i + 1 ;
}
@ -2752,8 +2735,8 @@ int whisper_full(
if ( seek + seek_delta + 100 > = seek_end ) {
result_len = i + 1 ;
} else {
// TODO: figure out how to resolve this
fprintf ( stderr , " \n %s: failed to generate timestamp token - this should not happen \n \n " , __func__ ) ;
failed = true ;
break ;
}
}
@ -2772,11 +2755,21 @@ int whisper_full(
}
}
if ( done ) {
// sometimes, the decoding can get stuck in a repetition loop
// this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
// the sliding window by 1 second
if ( i = = n_max - 1 & & ( result_len = = 0 | | seek_delta < 100 * WHISPER_CHUNK_SIZE / 2 ) ) {
failed = true ;
break ;
}
}
if ( failed ) {
fprintf ( stderr , " \n %s: failed to generate timestamp token - using fallback strategy \n \n " , __func__ ) ;
seek + = 100 ;
continue ;
}
// shrink down to result_len
tokens_cur . resize ( result_len ) ;