@ -206,6 +206,7 @@ struct whisper_vocab {
id token_sot = 50257 ;
id token_sot = 50257 ;
id token_prev = 50360 ;
id token_prev = 50360 ;
id token_solm = 50361 ; // ??
id token_solm = 50361 ; // ??
id token_not = 50362 ; // no timestamps
id token_beg = 50363 ;
id token_beg = 50363 ;
// available tasks
// available tasks
@ -217,17 +218,20 @@ struct whisper_vocab {
}
}
} ;
} ;
struct whisper_result {
whisper_vocab : : id id ;
int64_t t ;
} ;
// command-line parameters
// command-line parameters
struct whisper_params {
struct whisper_params {
int32_t seed = - 1 ; // RNG seed, not used currently
int32_t seed = - 1 ; // RNG seed, not used currently
int32_t n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ;
int32_t n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ;
// sampling parameter - used for the greedy strategy
int32_t max_tokens_per_iter = 64 ;
bool verbose = false ;
bool verbose = false ;
bool translate = false ;
bool translate = false ;
bool print_special_tokens = false ;
bool print_special_tokens = false ;
bool no_timestamps = false ;
std : : string language = " en " ;
std : : string language = " en " ;
std : : string model = " models/ggml-base.en.bin " ;
std : : string model = " models/ggml-base.en.bin " ;
@ -244,8 +248,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params . seed = std : : stoi ( argv [ + + i ] ) ;
params . seed = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -t " | | arg = = " --threads " ) {
} else if ( arg = = " -t " | | arg = = " --threads " ) {
params . n_threads = std : : stoi ( argv [ + + i ] ) ;
params . n_threads = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -T " | | arg = = " --tokens " ) {
params . max_tokens_per_iter = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -v " | | arg = = " --verbose " ) {
} else if ( arg = = " -v " | | arg = = " --verbose " ) {
params . verbose = true ;
params . verbose = true ;
} else if ( arg = = " --translate " ) {
} else if ( arg = = " --translate " ) {
@ -259,6 +261,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
}
}
} else if ( arg = = " -ps " | | arg = = " --print_special " ) {
} else if ( arg = = " -ps " | | arg = = " --print_special " ) {
params . print_special_tokens = true ;
params . print_special_tokens = true ;
} else if ( arg = = " -nt " | | arg = = " --no_timestamps " ) {
params . no_timestamps = true ;
} else if ( arg = = " -m " | | arg = = " --model " ) {
} else if ( arg = = " -m " | | arg = = " --model " ) {
params . model = argv [ + + i ] ;
params . model = argv [ + + i ] ;
} else if ( arg = = " -f " | | arg = = " --file " ) {
} else if ( arg = = " -f " | | arg = = " --file " ) {
@ -284,10 +288,10 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf ( stderr , " -h, --help show this help message and exit \n " ) ;
fprintf ( stderr , " -h, --help show this help message and exit \n " ) ;
fprintf ( stderr , " -s SEED, --seed SEED RNG seed (default: -1) \n " ) ;
fprintf ( stderr , " -s SEED, --seed SEED RNG seed (default: -1) \n " ) ;
fprintf ( stderr , " -t N, --threads N number of threads to use during computation (default: %d) \n " , params . n_threads ) ;
fprintf ( stderr , " -t N, --threads N number of threads to use during computation (default: %d) \n " , params . n_threads ) ;
fprintf ( stderr , " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d) \n " , params . max_tokens_per_iter ) ;
fprintf ( stderr , " -v, --verbose verbose output \n " ) ;
fprintf ( stderr , " -v, --verbose verbose output \n " ) ;
fprintf ( stderr , " --translate translate from source language to english \n " ) ;
fprintf ( stderr , " --translate translate from source language to english \n " ) ;
fprintf ( stderr , " -ps, --print_special print special tokens \n " ) ;
fprintf ( stderr , " -ps, --print_special print special tokens \n " ) ;
fprintf ( stderr , " -nt, --no_timestamps do not print timestamps \n " ) ;
fprintf ( stderr , " -l LANG, --language LANG spoken language (default: %s) \n " , params . language . c_str ( ) ) ;
fprintf ( stderr , " -l LANG, --language LANG spoken language (default: %s) \n " , params . language . c_str ( ) ) ;
fprintf ( stderr , " -m FNAME, --model FNAME model path (default: %s) \n " , params . model . c_str ( ) ) ;
fprintf ( stderr , " -m FNAME, --model FNAME model path (default: %s) \n " , params . model . c_str ( ) ) ;
fprintf ( stderr , " -f FNAME, --file FNAME input WAV file path (default: %s) \n " , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " -f FNAME, --file FNAME input WAV file path (default: %s) \n " , params . fname_inp . c_str ( ) ) ;
@ -591,6 +595,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
vocab . token_sot + + ;
vocab . token_sot + + ;
vocab . token_prev + + ;
vocab . token_prev + + ;
vocab . token_solm + + ;
vocab . token_solm + + ;
vocab . token_not + + ;
vocab . token_beg + + ;
vocab . token_beg + + ;
}
}
@ -605,6 +610,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
word = " [_SOT_] " ;
word = " [_SOT_] " ;
} else if ( i = = vocab . token_prev ) {
} else if ( i = = vocab . token_prev ) {
word = " [_PREV_] " ;
word = " [_PREV_] " ;
} else if ( i = = vocab . token_not ) {
word = " [_NOT_] " ;
} else if ( i = = vocab . token_beg ) {
} else if ( i = = vocab . token_beg ) {
word = " [_BEG_] " ;
word = " [_BEG_] " ;
} else {
} else {
@ -1842,15 +1849,13 @@ bool whisper_decode(
// TODO: temperature
// TODO: temperature
whisper_vocab : : id whisper_sample_best (
whisper_vocab : : id whisper_sample_best (
const whisper_vocab & vocab ,
const whisper_vocab & vocab ,
const float * probs ,
const float * probs ) {
double temp ,
int offset = 0 ) {
int n_logits = vocab . id_to_token . size ( ) ;
int n_logits = vocab . id_to_token . size ( ) ;
std : : vector < std : : pair < double , whisper_vocab : : id > > probs_id ;
std : : vector < std : : pair < double , whisper_vocab : : id > > probs_id ;
probs_id . reserve ( n_logits ) ;
probs_id . reserve ( n_logits ) ;
for ( int i = offset ; i < n_logits ; i + + ) {
for ( int i = 0 ; i < n_logits ; i + + ) {
probs_id . push_back ( std : : make_pair ( probs [ i ] , i ) ) ;
probs_id . push_back ( std : : make_pair ( probs [ i ] , i ) ) ;
}
}
@ -1872,13 +1877,49 @@ whisper_vocab::id whisper_sample_best(
//}
//}
int res = 0 ;
int res = 0 ;
while ( probs_id [ res ] . second = = vocab . token_solm & & res < ( int ) probs_id . size ( ) - 1 ) {
while ( ( probs_id [ res ] . second = = vocab . token_sot | |
probs_id [ res ] . second = = vocab . token_solm | |
probs_id [ res ] . second = = vocab . token_not ) & &
res < ( int ) probs_id . size ( ) - 1 ) {
res + + ;
res + + ;
}
}
return probs_id [ res ] . second ;
return probs_id [ res ] . second ;
}
}
// samples only from the timestamps tokens
whisper_vocab : : id whisper_sample_timestamp (
const whisper_vocab & vocab ,
const float * probs ) {
int n_logits = vocab . id_to_token . size ( ) ;
std : : vector < std : : pair < double , whisper_vocab : : id > > probs_id ;
probs_id . reserve ( n_logits ) ;
for ( int i = vocab . token_beg + 1 ; i < n_logits ; i + + ) {
probs_id . push_back ( std : : make_pair ( probs [ i ] , i ) ) ;
}
const int top_k = 10 ;
// find the top K tokens
std : : partial_sort (
probs_id . begin ( ) ,
probs_id . begin ( ) + top_k , probs_id . end ( ) ,
[ ] ( const std : : pair < double , whisper_vocab : : id > & a , const std : : pair < double , whisper_vocab : : id > & b ) {
return a . first > b . first ;
} ) ;
probs_id . resize ( top_k ) ;
//printf("\n");
//for (int i = 0; i < (int) probs_id.size(); i++) {
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
return probs_id [ 0 ] . second ;
}
// Cooley-Tukey FFT
// Cooley-Tukey FFT
// poor man's implmentation - use something better
// poor man's implmentation - use something better
// input is real-valued
// input is real-valued
@ -2032,6 +2073,20 @@ bool log_mel_spectrogram(
return true ;
return true ;
}
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
std : : string to_timestamp ( int64_t t ) {
int64_t sec = t / 100 ;
int64_t msec = t - sec * 100 ;
int64_t min = sec / 60 ;
sec = sec - min * 60 ;
char buf [ 32 ] ;
snprintf ( buf , sizeof ( buf ) , " %02d:%02d.%03d " , ( int ) min , ( int ) sec , ( int ) msec ) ;
return std : : string ( buf ) ;
}
int main ( int argc , char * * argv ) {
int main ( int argc , char * * argv ) {
const int64_t t_main_start_us = ggml_time_us ( ) ;
const int64_t t_main_start_us = ggml_time_us ( ) ;
@ -2128,10 +2183,12 @@ int main(int argc, char ** argv) {
printf ( " %s: WARNING: model is not multilingual, ignoring language and translation options \n " , __func__ ) ;
printf ( " %s: WARNING: model is not multilingual, ignoring language and translation options \n " , __func__ ) ;
}
}
}
}
printf ( " %s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n " ,
printf ( " %s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s , timestamps = %d ...\n " ,
__func__ , int ( pcmf32 . size ( ) ) , float ( pcmf32 . size ( ) ) / SAMPLE_RATE , params . n_threads ,
__func__ , int ( pcmf32 . size ( ) ) , float ( pcmf32 . size ( ) ) / SAMPLE_RATE , params . n_threads ,
g_lang . at ( params . language ) . second . c_str ( ) ,
g_lang . at ( params . language ) . second . c_str ( ) ,
params . translate ? " translate " : " transcribe " ) ;
params . translate ? " translate " : " transcribe " ,
params . no_timestamps ? 0 : 1 ) ;
printf ( " \n " ) ;
}
}
// the accumulated text context so far
// the accumulated text context so far
@ -2148,6 +2205,9 @@ int main(int argc, char ** argv) {
}
}
}
}
// the generated text including timestamps
std : : vector < whisper_result > result_all ;
// main loop
// main loop
int seek = 0 ;
int seek = 0 ;
while ( true ) {
while ( true ) {
@ -2165,7 +2225,7 @@ int main(int argc, char ** argv) {
return 1 ;
return 1 ;
}
}
t_encode_us = ggml_time_us ( ) - t_start_us ;
t_encode_us + = ggml_time_us ( ) - t_start_us ;
}
}
std : : vector < float > probs ;
std : : vector < float > probs ;
@ -2192,11 +2252,16 @@ int main(int argc, char ** argv) {
int seek_delta = 100 * CHUNK_SIZE ;
int seek_delta = 100 * CHUNK_SIZE ;
whisper_vocab : : id last_id = 0 ;
whisper_vocab : : id last_id = 0 ;
//printf("\n\n");
//for (int i = 0; i < prompt.size(); i++) {
//for (int i = 0; i < prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
// printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
//}
//}
//printf("\n\n");
// the accumulated transcription in the current interation
int result_len = 0 ;
std : : vector < whisper_result > result_cur ;
printf ( " \n " ) ;
for ( int i = 0 ; i < model . hparams . n_text_ctx / 2 ; + + i ) {
for ( int i = 0 ; i < model . hparams . n_text_ctx / 2 ; + + i ) {
// decode
// decode
if ( prompt . size ( ) > 0 ) {
if ( prompt . size ( ) > 0 ) {
@ -2216,63 +2281,118 @@ int main(int argc, char ** argv) {
// very basic greedy sampling strategy:
// very basic greedy sampling strategy:
//
//
// - always take the most probable token
// - always take the most probable token
// - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token
// and advance the sliding window by that amount
// - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too
//
//
// more sophisticated sampling strategies could be implemented here, but we keep it simple
// more sophisticated sampling strategies could be implemented here, but we keep it simple
// feel free to experiment!
// feel free to experiment!
//
//
{
{
// sample next token
const float temp = 1.0 ; // TODO
const int n_vocab = model . hparams . n_vocab ;
const int n_vocab = model . hparams . n_vocab ;
whisper_vocab : : id id = 0 ;
whisper_vocab : : id id = 0 ;
whisper_vocab : : id tid = vocab . token_beg ;
{
{
const int64_t t_start_sample_us = ggml_time_us ( ) ;
const int64_t t_start_sample_us = ggml_time_us ( ) ;
id = whisper_sample_best ( vocab , probs . data ( ) + ( probs . size ( ) - n_vocab ) , temp , i > params . max_tokens_per_iter ? vocab . token_beg : 0 );
id = whisper_sample_best ( vocab , probs . data ( ) + ( probs . size ( ) - n_vocab ) );
if ( i > 0 ) {
t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
tid = whisper_sample_timestamp ( vocab , probs . data ( ) + ( probs . size ( ) - n_vocab ) ) ;
}
}
// end of text token
t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
if ( id = = vocab . token_eot ) {
break ;
}
}
// 2 consecutive time tokens
// update sliding window
if ( id > vocab . token_beg & & last_id > vocab . token_beg ) {
if ( id > vocab . token_beg ) {
seek_delta = 2 * ( id - vocab . token_beg ) ;
seek_delta = 2 * ( id - vocab . token_beg ) ;
done = true ;
result_len = i + 1 ;
}
}
last_id = id ;
last_id = id ;
// add it to the context
// add it to the context
prompt . push_back ( id ) ;
prompt . push_back ( id ) ;
prompt_past . push_back ( id ) ;
result_cur . push_back ( { id , seek + 2 * ( tid - vocab . token_beg ) } ) ;
}
// display text
// end of text token
for ( auto id : prompt ) {
if ( id = = vocab . token_eot ) {
if ( params . print_special_tokens = = false & & id > = vocab . token_eot ) {
break ;
continue ;
}
}
printf ( " %s " , vocab . id_to_token [ id ] . c_str ( ) ) ;
}
}
fflush ( stdout ) ;
if ( done ) {
if ( done ) {
break ;
break ;
}
}
}
}
result_cur . resize ( result_len ) ;
result_all . insert ( result_all . end ( ) , result_cur . begin ( ) , result_cur . end ( ) ) ;
for ( const auto & r : result_cur ) {
prompt_past . push_back ( r . id ) ;
}
// print the text from this iteration
if ( result_cur . size ( ) > 0 ) {
auto t0 = result_cur . front ( ) . t ;
std : : string text = " " ;
for ( int i = 0 ; i < result_cur . size ( ) ; i + + ) {
if ( params . print_special_tokens = = false & & result_cur [ i ] . id > = vocab . token_eot ) {
} else {
text + = vocab . id_to_token [ result_cur [ i ] . id ] ;
}
if ( result_cur [ i ] . id > vocab . token_beg ) {
const auto t1 = result_cur [ i ] . t ;
if ( ! text . empty ( ) ) {
if ( params . no_timestamps ) {
printf ( " %s " , text . c_str ( ) ) ;
fflush ( stdout ) ;
} else {
printf ( " [%s --> %s] %s \n " , to_timestamp ( t0 ) . c_str ( ) , to_timestamp ( t1 ) . c_str ( ) , text . c_str ( ) ) ;
}
}
text = " " ;
while ( result_cur [ i ] . id > vocab . token_beg & & i < result_cur . size ( ) ) {
i + + ;
}
i - - ;
t0 = result_cur [ i ] . t ;
}
}
if ( ! text . empty ( ) ) {
printf ( " [%s --> %s] %s \n " , to_timestamp ( t0 ) . c_str ( ) , to_timestamp ( seek + seek_delta ) . c_str ( ) , text . c_str ( ) ) ;
}
}
seek + = seek_delta ;
seek + = seek_delta ;
}
}
// WIP: attempt for per-token timestamps
//if (!params.no_timestamps && result_all.size() > 0) {
// const int64_t dt = 500; // 5 second intervals
// int i0 = 0;
// int64_t t0 = result_all[0].t;
// int64_t t1 = t0;
// printf("\n\n");
// for (int i = 0; i < result_all.size(); ++i) {
// printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t);
// if (result_all[i].t - t0 > dt) {
// t1 = result_all[i - 1].t;
// printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
// for (int j = i0; j < i; ++j) {
// printf("%s", vocab.id_to_token.at(result_all[j].id).c_str());
// }
// printf("\n");
// i0 = i;
// t0 = result_all[i].t;
// }
// }
//}
// report timing
// report timing
{
{
const int64_t t_main_end_us = ggml_time_us ( ) ;
const int64_t t_main_end_us = ggml_time_us ( ) ;