@ -405,6 +405,8 @@ struct whisper_context {
std : : vector < whisper_result > result_cur ;
std : : vector < whisper_segment > result_all ;
std : : vector < whisper_token > prompt_past ;
} ;
// load the model from a ggml file
@ -419,7 +421,7 @@ struct whisper_context {
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load ( const std : : string & fname , whisper_context & wctx ) {
printf( " %s: loading model from '%s' \n " , __func__ , fname . c_str ( ) ) ;
f printf( stderr , " %s: loading model from '%s' \n " , __func__ , fname . c_str ( ) ) ;
auto & model = wctx . model ;
auto & vocab = wctx . vocab ;
@ -478,18 +480,18 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
model . type = e_model : : MODEL_LARGE ;
}
printf( " %s: n_vocab = %d \n " , __func__ , hparams . n_vocab ) ;
printf( " %s: n_audio_ctx = %d \n " , __func__ , hparams . n_audio_ctx ) ;
printf( " %s: n_audio_state = %d \n " , __func__ , hparams . n_audio_state ) ;
printf( " %s: n_audio_head = %d \n " , __func__ , hparams . n_audio_head ) ;
printf( " %s: n_audio_layer = %d \n " , __func__ , hparams . n_audio_layer ) ;
printf( " %s: n_text_ctx = %d \n " , __func__ , hparams . n_text_ctx ) ;
printf( " %s: n_text_state = %d \n " , __func__ , hparams . n_text_state ) ;
printf( " %s: n_text_head = %d \n " , __func__ , hparams . n_text_head ) ;
printf( " %s: n_text_layer = %d \n " , __func__ , hparams . n_text_layer ) ;
printf( " %s: n_mels = %d \n " , __func__ , hparams . n_mels ) ;
printf( " %s: f16 = %d \n " , __func__ , hparams . f16 ) ;
printf( " %s: type = %d \n " , __func__ , model . type ) ;
f printf( stderr , " %s: n_vocab = %d \n " , __func__ , hparams . n_vocab ) ;
f printf( stderr , " %s: n_audio_ctx = %d \n " , __func__ , hparams . n_audio_ctx ) ;
f printf( stderr , " %s: n_audio_state = %d \n " , __func__ , hparams . n_audio_state ) ;
f printf( stderr , " %s: n_audio_head = %d \n " , __func__ , hparams . n_audio_head ) ;
f printf( stderr , " %s: n_audio_layer = %d \n " , __func__ , hparams . n_audio_layer ) ;
f printf( stderr , " %s: n_text_ctx = %d \n " , __func__ , hparams . n_text_ctx ) ;
f printf( stderr , " %s: n_text_state = %d \n " , __func__ , hparams . n_text_state ) ;
f printf( stderr , " %s: n_text_head = %d \n " , __func__ , hparams . n_text_head ) ;
f printf( stderr , " %s: n_text_layer = %d \n " , __func__ , hparams . n_text_layer ) ;
f printf( stderr , " %s: n_mels = %d \n " , __func__ , hparams . n_mels ) ;
f printf( stderr , " %s: f16 = %d \n " , __func__ , hparams . f16 ) ;
f printf( stderr , " %s: type = %d \n " , __func__ , model . type ) ;
wctx . buf_model . resize ( MEM_REQ_MODEL . at ( model . type ) ) ;
wctx . buf_compute . resize ( std : : max ( MEM_REQ_ENCODE . at ( model . type ) , MEM_REQ_DECODE . at ( model . type ) ) ) ;
@ -501,7 +503,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
wctx . buf_compute . size ( ) +
wctx . buf_compute_layer . size ( ) ;
printf( " %s: mem_required = %.2f MB \n " , __func__ , mem_required / 1024.0 / 1024.0 ) ;
f printf( stderr , " %s: mem_required = %.2f MB \n " , __func__ , mem_required / 1024.0 / 1024.0 ) ;
}
// load mel filters
@ -551,7 +553,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
if ( n_vocab < model . hparams . n_vocab ) {
printf( " %s: adding %d extra tokens \n " , __func__ , model . hparams . n_vocab - n_vocab ) ;
f printf( stderr , " %s: adding %d extra tokens \n " , __func__ , model . hparams . n_vocab - n_vocab ) ;
for ( int i = n_vocab ; i < model . hparams . n_vocab ; i + + ) {
if ( i > vocab . token_beg ) {
word = " [_TT_ " + std : : to_string ( i - vocab . token_beg ) + " ] " ;
@ -696,7 +698,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
ctx_size + = ( 15 + 15 * n_audio_layer + 24 * n_text_layer ) * 256 ; // object overhead
printf( " %s: ggml ctx size = %6.2f MB \n " , __func__ , ctx_size / ( 1024.0 * 1024.0 ) ) ;
f printf( stderr , " %s: ggml ctx size = %6.2f MB \n " , __func__ , ctx_size / ( 1024.0 * 1024.0 ) ) ;
}
// create the ggml context
@ -943,11 +945,12 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
ggml_nbytes ( model . memory_k ) + ggml_nbytes ( model . memory_v ) +
ggml_nbytes ( model . memory_cross_k ) + ggml_nbytes ( model . memory_cross_v ) ;
printf( " %s: memory size = %8.2f MB \n " , __func__ , memory_size / 1024.0 / 1024.0 ) ;
f printf( stderr , " %s: memory size = %8.2f MB \n " , __func__ , memory_size / 1024.0 / 1024.0 ) ;
}
// load weights
{
int n_loaded = 0 ;
size_t total_size = 0 ;
while ( true ) {
@ -1002,9 +1005,17 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size + = ggml_nbytes ( tensor ) ;
n_loaded + + ;
}
printf ( " %s: model size = %8.2f MB \n " , __func__ , total_size / 1024.0 / 1024.0 ) ;
fprintf ( stderr , " %s: model size = %8.2f MB \n " , __func__ , total_size / 1024.0 / 1024.0 ) ;
if ( n_loaded = = 0 ) {
fprintf ( stderr , " %s: WARN no tensors loaded from model file - assuming empty model for testing \n " , __func__ ) ;
} else if ( n_loaded ! = ( int ) model . tensors . size ( ) ) {
fprintf ( stderr , " %s: ERROR not all tensors loaded from model file - expected %zu, got %d \n " , __func__ , model . tensors . size ( ) , n_loaded ) ;
return false ;
}
}
fin . close ( ) ;
@ -1020,8 +1031,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - model: the model
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
// - mel_inp: input mel spectrogram
// - features: output encoded features
//
bool whisper_encode (
whisper_context & wctx ,
@ -1405,10 +1414,9 @@ bool whisper_encode(
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: prompt length
// - prompt: text prompt
// - logits_out: output logits
// - probs_out: output probabilities
// - tokens: text prompt
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode (
whisper_context & wctx ,
@ -1773,8 +1781,6 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
// TODO: beam search
// TODO: temperature
whisper_vocab : : id whisper_sample_best (
const whisper_vocab & vocab ,
const float * probs , bool need_timestamp ) {
@ -2236,13 +2242,13 @@ whisper_token whisper_token_transcribe() {
void whisper_print_timings ( struct whisper_context * ctx ) {
const int64_t t_end_us = ggml_time_us ( ) ;
printf( " \n \n " ) ;
printf( " %s: load time = %8.2f ms \n " , __func__ , ctx - > t_load_us / 1000.0f ) ;
printf( " %s: mel time = %8.2f ms \n " , __func__ , ctx - > t_mel_us / 1000.0f ) ;
printf( " %s: sample time = %8.2f ms \n " , __func__ , ctx - > t_sample_us / 1000.0f ) ;
printf( " %s: encode time = %8.2f ms / %.2f ms per layer \n " , __func__ , ctx - > t_encode_us / 1000.0f , ctx - > t_encode_us / 1000.0f / ctx - > model . hparams . n_audio_layer ) ;
printf( " %s: decode time = %8.2f ms / %.2f ms per layer \n " , __func__ , ctx - > t_decode_us / 1000.0f , ctx - > t_decode_us / 1000.0f / ctx - > model . hparams . n_text_layer ) ;
printf( " %s: total time = %8.2f ms \n " , __func__ , ( t_end_us - ctx - > t_start_us ) / 1000.0f ) ;
f printf( stderr , " \n " ) ;
f printf( stderr , " %s: load time = %8.2f ms \n " , __func__ , ctx - > t_load_us / 1000.0f ) ;
f printf( stderr , " %s: mel time = %8.2f ms \n " , __func__ , ctx - > t_mel_us / 1000.0f ) ;
f printf( stderr , " %s: sample time = %8.2f ms \n " , __func__ , ctx - > t_sample_us / 1000.0f ) ;
f printf( stderr , " %s: encode time = %8.2f ms / %.2f ms per layer \n " , __func__ , ctx - > t_encode_us / 1000.0f , ctx - > t_encode_us / 1000.0f / ctx - > model . hparams . n_audio_layer ) ;
f printf( stderr , " %s: decode time = %8.2f ms / %.2f ms per layer \n " , __func__ , ctx - > t_decode_us / 1000.0f , ctx - > t_decode_us / 1000.0f / ctx - > model . hparams . n_text_layer ) ;
f printf( stderr , " %s: total time = %8.2f ms \n " , __func__ , ( t_end_us - ctx - > t_start_us ) / 1000.0f ) ;
}
////////////////////////////////////////////////////////////////////////////
@ -2256,8 +2262,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
result = ( struct whisper_full_params ) {
. strategy = WHISPER_DECODE_GREEDY ,
. n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ,
. offset_ms = 0 ,
. translate = false ,
. no_context = false ,
. print_special_tokens = false ,
. print_progress = true ,
. print_realtime = false ,
@ -2275,8 +2283,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
result = ( struct whisper_full_params ) {
. strategy = WHISPER_DECODE_GREEDY ,
. n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ,
. offset_ms = 0 ,
. translate = false ,
. no_context = false ,
. print_special_tokens = false ,
. print_progress = true ,
. print_realtime = false ,
@ -2295,6 +2305,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
return result ;
}
int whisper_full (
struct whisper_context * ctx ,
struct whisper_full_params params ,
@ -2307,7 +2318,10 @@ int whisper_full(
}
// the accumulated text context so far
std : : vector < whisper_token > prompt_past = { } ;
auto & prompt_past = ctx - > prompt_past ;
if ( params . no_context ) {
prompt_past . clear ( ) ;
}
// these tokens determine the task that will be performed
std : : vector < whisper_token > prompt_init = { whisper_token_sot ( ctx ) } ;
@ -2329,13 +2343,13 @@ int whisper_full(
int progress_step = 5 ;
// main loop
int seek = 0;
int seek = params . offset_ms / 1 0;
while ( true ) {
int progress_cur = ( 100 * seek ) / whisper_n_len ( ctx ) ;
while ( progress_cur > = progress_prev + progress_step ) {
progress_prev + = progress_step ;
if ( params . print_progress ) {
printf( " %s: progress = %3d%% \n " , __func__ , progress_prev ) ;
f printf( stderr , " %s: progress = %3d%% \n " , __func__ , progress_prev ) ;
}
}
@ -2463,7 +2477,7 @@ int whisper_full(
result_all . push_back ( { t0 , t1 , text } ) ;
}
text = " " ;
while ( result_cur[ i ] . id > whisper_token_beg ( ctx ) & & i < ( int ) result_cur . size ( ) ) {
while ( i < ( int ) result_cur . size ( ) & & result_cur[ i ] . id > whisper_token_beg ( ctx ) ) {
i + + ;
}
i - - ;