@ -567,7 +567,7 @@ struct whisper_state {
whisper_decoder decoders [ WHISPER_MAX_DECODERS ] = { } ;
// memory buffers used by encode / decode contexts
std : : vector < uint8_t > buf_compute {} ;
std : : vector < uint8_t > buf_compute ;
std : : vector < uint8_t > buf_scratch [ WHISPER_MAX_SCRATCH_BUFFERS ] ;
int buf_last = 0 ;
@ -576,11 +576,11 @@ struct whisper_state {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std : : vector < float > logits ;
std : : vector < whisper_segment > result_all {} ;
std : : vector < whisper_token > prompt_past {} ;
std : : vector < whisper_segment > result_all ;
std : : vector < whisper_token > prompt_past ;
// work container used to avoid memory allocations
std : : vector < std : : pair < double , whisper_vocab : : id > > logits_id {} ;
std : : vector < std : : pair < double , whisper_vocab : : id > > logits_id ;
mutable std : : mt19937 rng ; // used for sampling at t > 0.0
@ -590,20 +590,19 @@ struct whisper_state {
int64_t t_beg = 0 ;
int64_t t_last = 0 ;
whisper_token tid_last ;
std : : vector < float > energy {} ; // PCM signal energy
std : : vector < float > energy ; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0 ; // 0 - use default
void use_buf ( struct ggml_context * ctx , int i ) {
void use_buf ( struct ggml_context * ctx , int i ) {
# if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0 ;
if ( i = = - 1 ) {
last_size = ggml_set_scratch ( ctx , { 0 , 0 , nullptr , } ) ;
}
else {
auto & buf = buf_scratch [ i ] ;
} else {
auto & buf = buf_scratch [ i ] ;
last_size = ggml_set_scratch ( ctx , { 0 , buf . size ( ) , buf . data ( ) , } ) ;
}
@ -613,8 +612,8 @@ struct whisper_state {
buf_last = i ;
# else
( void ) i ;
( void ) ctx ;
( void ) i ;
( void ) ctx ;
# endif
}
@ -622,7 +621,7 @@ struct whisper_state {
# if defined(WHISPER_USE_SCRATCH)
return buf_max_size [ i ] ;
# else
( void ) i ;
( void ) i ;
return 0 ;
# endif
}
@ -637,7 +636,7 @@ struct whisper_context {
whisper_model model ;
whisper_vocab vocab ;
whisper_state * default_ state = nullptr ;
whisper_state * state = nullptr ;
} ;
template < typename T >
@ -2475,19 +2474,6 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
return tokens ;
}
// Initializes default state in the context
// returns true if new state was initiallized or false if the state already exists
//
bool whisper_init_default_state ( struct whisper_context * wctx )
{
if ( wctx - > default_state = = nullptr ) {
wctx - > default_state = whisper_init_state ( wctx ) ;
return true ;
}
return false ;
}
//
// interface implementation
//
@ -2540,7 +2526,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return state ;
}
struct whisper_context * whisper_init_from_file ( const char * path_model ) {
struct whisper_context * whisper_init_from_file _no_state ( const char * path_model ) {
whisper_model_loader loader = { } ;
fprintf ( stderr , " %s: loading model from '%s' \n " , __func__ , path_model ) ;
@ -2568,10 +2554,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
fin - > close ( ) ;
} ;
return whisper_init ( & loader ) ;
return whisper_init _no_state ( & loader ) ;
}
struct whisper_context * whisper_init_from_buffer ( void * buffer , size_t buffer_size ) {
struct whisper_context * whisper_init_from_buffer _no_state ( void * buffer , size_t buffer_size ) {
struct buf_context {
uint8_t * buffer ;
size_t size ;
@ -2604,10 +2590,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
loader . close = [ ] ( void * /*ctx*/ ) { } ;
return whisper_init ( & loader ) ;
return whisper_init _no_state ( & loader ) ;
}
struct whisper_context * whisper_init ( struct whisper_model_loader * loader ) {
struct whisper_context * whisper_init _no_state ( struct whisper_model_loader * loader ) {
ggml_time_init ( ) ;
whisper_context * ctx = new whisper_context ;
@ -2624,7 +2610,52 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
return ctx ;
}
void whisper_free_state ( struct whisper_state * state )
struct whisper_context * whisper_init_from_file ( const char * path_model ) {
whisper_context * ctx = whisper_init_from_file_no_state ( path_model ) ;
if ( ! ctx ) {
return nullptr ;
}
ctx - > state = whisper_init_state ( ctx ) ;
if ( ! ctx - > state ) {
whisper_free ( ctx ) ;
return nullptr ;
}
return ctx ;
}
struct whisper_context * whisper_init_from_buffer ( void * buffer , size_t buffer_size ) {
whisper_context * ctx = whisper_init_from_buffer_no_state ( buffer , buffer_size ) ;
if ( ! ctx ) {
return nullptr ;
}
ctx - > state = whisper_init_state ( ctx ) ;
if ( ! ctx - > state ) {
whisper_free ( ctx ) ;
return nullptr ;
}
return ctx ;
}
struct whisper_context * whisper_init ( struct whisper_model_loader * loader ) {
whisper_context * ctx = whisper_init_no_state ( loader ) ;
if ( ! ctx ) {
return nullptr ;
}
ctx - > state = whisper_init_state ( ctx ) ;
if ( ! ctx - > state ) {
whisper_free ( ctx ) ;
return nullptr ;
}
return ctx ;
}
void whisper_free_state ( struct whisper_state * state )
{
if ( state ) {
if ( state - > kv_cross . ctx ) {
@ -2648,15 +2679,13 @@ void whisper_free(struct whisper_context * ctx) {
delete ctx - > model . buf ;
}
whisper_free_state ( ctx - > default_ state) ;
whisper_free_state ( ctx - > state) ;
delete ctx ;
}
}
int whisper_pcm_to_mel_with_state ( struct whisper_context * ctx , struct whisper_state * state , const float * samples , int n_samples , int n_threads ) {
whisper_init_default_state ( ctx ) ;
if ( ! log_mel_spectrogram ( * state , samples , n_samples , WHISPER_SAMPLE_RATE , WHISPER_N_FFT , WHISPER_HOP_LENGTH , WHISPER_N_MEL , n_threads , ctx - > model . filters , false , state - > mel ) ) {
fprintf ( stderr , " %s: failed to compute mel spectrogram \n " , __func__ ) ;
return - 1 ;
@ -2666,9 +2695,7 @@ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_s
}
int whisper_pcm_to_mel ( struct whisper_context * ctx , const float * samples , int n_samples , int n_threads ) {
whisper_init_default_state ( ctx ) ;
return whisper_pcm_to_mel_with_state ( ctx , ctx - > default_state , samples , n_samples , n_threads ) ;
return whisper_pcm_to_mel_with_state ( ctx , ctx - > state , samples , n_samples , n_threads ) ;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
@ -2683,9 +2710,7 @@ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, st
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder ( struct whisper_context * ctx , const float * samples , int n_samples , int n_threads ) {
whisper_init_default_state ( ctx ) ;
return whisper_pcm_to_mel_phase_vocoder_with_state ( ctx , ctx - > default_state , samples , n_samples , n_threads ) ;
return whisper_pcm_to_mel_phase_vocoder_with_state ( ctx , ctx - > state , samples , n_samples , n_threads ) ;
}
int whisper_set_mel_with_state (
@ -2713,9 +2738,7 @@ int whisper_set_mel(
const float * data ,
int n_len ,
int n_mel ) {
whisper_init_default_state ( ctx ) ;
return whisper_set_mel_with_state ( ctx , ctx - > default_state , data , n_len , n_mel ) ;
return whisper_set_mel_with_state ( ctx , ctx - > state , data , n_len , n_mel ) ;
}
int whisper_encode_with_state ( struct whisper_context * ctx , struct whisper_state * state , int offset , int n_threads ) {
@ -2728,7 +2751,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
}
int whisper_encode ( struct whisper_context * ctx , int offset , int n_threads ) {
if ( ! whisper_encode_internal ( * ctx , * ctx - > default_ state, offset , n_threads ) ) {
if ( ! whisper_encode_internal ( * ctx , * ctx - > state, offset , n_threads ) ) {
fprintf ( stderr , " %s: failed to eval \n " , __func__ ) ;
return - 1 ;
}
@ -2751,13 +2774,13 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
// TODO: add selected_decoder_id to state
const int selected_decoder_id = 0 ;
if ( ctx - > default_ state = = nullptr ) {
fprintf ( stderr , " %s: ERROR default_ state was not loaded.\n " , __func__ ) ;
if ( ctx - > state = = nullptr ) {
fprintf ( stderr , " %s: ERROR state was not loaded.\n " , __func__ ) ;
return false ;
}
if ( ! whisper_decode_internal ( * ctx , * ctx - > default_ state, ctx - > default_ state- > decoders [ selected_decoder_id ] , tokens , n_tokens , n_past , n_threads ) ) {
if ( ! whisper_decode_internal ( * ctx , * ctx - > state, ctx - > state- > decoders [ selected_decoder_id ] , tokens , n_tokens , n_past , n_threads ) ) {
fprintf ( stderr , " %s: failed to eval \n " , __func__ ) ;
return 1 ;
}
@ -2895,8 +2918,7 @@ int whisper_lang_auto_detect(
int offset_ms ,
int n_threads ,
float * lang_probs ) {
whisper_init_default_state ( ctx ) ;
return whisper_lang_auto_detect_with_state ( ctx , ctx - > default_state , offset_ms , n_threads , lang_probs ) ;
return whisper_lang_auto_detect_with_state ( ctx , ctx - > state , offset_ms , n_threads , lang_probs ) ;
}
int whisper_n_len_from_state ( struct whisper_state * state ) {
@ -2904,7 +2926,7 @@ int whisper_n_len_from_state(struct whisper_state * state) {
}
int whisper_n_len ( struct whisper_context * ctx ) {
return ctx - > default_ state- > mel . n_len ;
return ctx - > state- > mel . n_len ;
}
int whisper_n_vocab ( struct whisper_context * ctx ) {
@ -2924,7 +2946,7 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
}
float * whisper_get_logits ( struct whisper_context * ctx ) {
return ctx - > default_ state- > logits . data ( ) ;
return ctx - > state- > logits . data ( ) ;
}
@ -2977,27 +2999,26 @@ void whisper_print_timings(struct whisper_context * ctx) {
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " %s: load time = %8.2f ms \n " , __func__ , ctx - > t_load_us / 1000.0f ) ;
if ( ctx - > default_ state ! = nullptr ) {
if ( ctx - > state ! = nullptr ) {
const int32_t n_sample = std : : max ( 1 , ctx - > default_ state- > n_sample ) ;
const int32_t n_encode = std : : max ( 1 , ctx - > default_ state- > n_encode ) ;
const int32_t n_decode = std : : max ( 1 , ctx - > default_ state- > n_decode ) ;
const int32_t n_sample = std : : max ( 1 , ctx - > state- > n_sample ) ;
const int32_t n_encode = std : : max ( 1 , ctx - > state- > n_encode ) ;
const int32_t n_decode = std : : max ( 1 , ctx - > state- > n_decode ) ;
fprintf ( stderr , " %s: fallbacks = %3d p / %3d h \n " , __func__ , ctx - > default_ state- > n_fail_p , ctx - > default_ state- > n_fail_h ) ;
fprintf ( stderr , " %s: mel time = %8.2f ms \n " , __func__ , ctx - > default_ state- > t_mel_us / 1000.0f ) ;
fprintf ( stderr , " %s: sample time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > default_ state- > t_sample_us , n_sample , 1e-3 f * ctx - > default_ state- > t_sample_us / n_sample ) ;
fprintf ( stderr , " %s: encode time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > default_ state- > t_encode_us , n_encode , 1e-3 f * ctx - > default_ state- > t_encode_us / n_encode ) ;
fprintf ( stderr , " %s: decode time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > default_ state- > t_decode_us , n_decode , 1e-3 f * ctx - > default_ state- > t_decode_us / n_decode ) ;
fprintf ( stderr , " %s: fallbacks = %3d p / %3d h \n " , __func__ , ctx - > state- > n_fail_p , ctx - > state- > n_fail_h ) ;
fprintf ( stderr , " %s: mel time = %8.2f ms \n " , __func__ , ctx - > state- > t_mel_us / 1000.0f ) ;
fprintf ( stderr , " %s: sample time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > state- > t_sample_us , n_sample , 1e-3 f * ctx - > state- > t_sample_us / n_sample ) ;
fprintf ( stderr , " %s: encode time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > state- > t_encode_us , n_encode , 1e-3 f * ctx - > state- > t_encode_us / n_encode ) ;
fprintf ( stderr , " %s: decode time = %8.2f ms / %5d runs (%8.2f ms per run) \n " , __func__ , 1e-3 f * ctx - > state- > t_decode_us , n_decode , 1e-3 f * ctx - > state- > t_decode_us / n_decode ) ;
}
fprintf ( stderr , " %s: total time = %8.2f ms \n " , __func__ , ( t_end_us - ctx - > t_start_us ) / 1000.0f ) ;
}
void whisper_reset_timings ( struct whisper_context * ctx ) {
if ( ctx - > default_state ! = nullptr ) {
ctx - > default_state - > t_sample_us = 0 ;
ctx - > default_state - > t_encode_us = 0 ;
ctx - > default_state - > t_decode_us = 0 ;
if ( ctx - > state ! = nullptr ) {
ctx - > state - > t_sample_us = 0 ;
ctx - > state - > t_encode_us = 0 ;
ctx - > state - > t_decode_us = 0 ;
}
}
@ -4335,9 +4356,7 @@ int whisper_full(
struct whisper_full_params params ,
const float * samples ,
int n_samples ) {
//This is not thread safe and it's using default_state
whisper_init_default_state ( ctx ) ;
return whisper_full_with_state ( ctx , ctx - > default_state , params , samples , n_samples ) ;
return whisper_full_with_state ( ctx , ctx - > state , params , samples , n_samples ) ;
}
int whisper_full_parallel (
@ -4349,12 +4368,10 @@ int whisper_full_parallel(
if ( n_processors = = 1 ) {
return whisper_full ( ctx , params , samples , n_samples ) ;
}
whisper_init_default_state ( ctx ) ;
int ret = 0 ;
// prepare separate states for each thread
std : : vector < whisper_state * > states {} ;
std : : vector < whisper_state * > states ;
const int offset_samples = ( WHISPER_SAMPLE_RATE * params . offset_ms ) / 1000 ;
const int n_samples_per_processor = ( n_samples - offset_samples ) / n_processors ;
@ -4389,7 +4406,7 @@ int whisper_full_parallel(
params_cur . print_realtime = false ;
// Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state ( ctx , ctx - > default_ state, std : : move ( params_cur ) , samples , offset_samples + n_samples_per_processor ) ;
ret = whisper_full_with_state ( ctx , ctx - > state, std : : move ( params_cur ) , samples , offset_samples + n_samples_per_processor ) ;
}
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
@ -4409,32 +4426,32 @@ int whisper_full_parallel(
// make sure that segments are not overlapping
if ( ! ctx - > default_ state- > result_all . empty ( ) ) {
result . t0 = std : : max ( result . t0 , ctx - > default_ state- > result_all . back ( ) . t1 ) ;
if ( ! ctx - > state- > result_all . empty ( ) ) {
result . t0 = std : : max ( result . t0 , ctx - > state- > result_all . back ( ) . t1 ) ;
}
ctx - > default_ state- > result_all . push_back ( std : : move ( result ) ) ;
ctx - > state- > result_all . push_back ( std : : move ( result ) ) ;
// call the new_segment_callback for each segment
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , ctx - > default_ state, 1 , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , ctx - > state, 1 , params . new_segment_callback_user_data ) ;
}
}
ctx - > default_ state- > t_mel_us + = states [ i ] - > t_mel_us ;
ctx - > state- > t_mel_us + = states [ i ] - > t_mel_us ;
ctx - > default_ state- > t_sample_us + = states [ i ] - > t_sample_us ;
ctx - > default_ state- > t_encode_us + = states [ i ] - > t_encode_us ;
ctx - > default_ state- > t_decode_us + = states [ i ] - > t_decode_us ;
ctx - > state- > t_sample_us + = states [ i ] - > t_sample_us ;
ctx - > state- > t_encode_us + = states [ i ] - > t_encode_us ;
ctx - > state- > t_decode_us + = states [ i ] - > t_decode_us ;
whisper_free_state ( states [ i ] ) ;
}
// average the timings
ctx - > default_ state- > t_mel_us / = n_processors ;
ctx - > default_ state- > t_sample_us / = n_processors ;
ctx - > default_ state- > t_encode_us / = n_processors ;
ctx - > default_ state- > t_decode_us / = n_processors ;
ctx - > state- > t_mel_us / = n_processors ;
ctx - > state- > t_sample_us / = n_processors ;
ctx - > state- > t_encode_us / = n_processors ;
ctx - > state- > t_decode_us / = n_processors ;
// print information about the audio boundaries
fprintf ( stderr , " \n " ) ;
@ -4452,7 +4469,7 @@ int whisper_full_n_segments_from_state(struct whisper_state * state) {
}
int whisper_full_n_segments ( struct whisper_context * ctx ) {
return ctx - > default_ state- > result_all . size ( ) ;
return ctx - > state- > result_all . size ( ) ;
}
int whisper_full_lang_id_from_state ( struct whisper_state * state ) {
@ -4460,7 +4477,7 @@ int whisper_full_lang_id_from_state(struct whisper_state * state) {
}
int whisper_full_lang_id ( struct whisper_context * ctx ) {
return ctx - > default_ state- > lang_id ;
return ctx - > state- > lang_id ;
}
int64_t whisper_full_get_segment_t0_from_state ( struct whisper_state * state , int i_segment ) {
@ -4468,7 +4485,7 @@ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int
}
int64_t whisper_full_get_segment_t0 ( struct whisper_context * ctx , int i_segment ) {
return ctx - > default_ state- > result_all [ i_segment ] . t0 ;
return ctx - > state- > result_all [ i_segment ] . t0 ;
}
int64_t whisper_full_get_segment_t1_from_state ( struct whisper_state * state , int i_segment ) {
@ -4476,7 +4493,7 @@ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int
}
int64_t whisper_full_get_segment_t1 ( struct whisper_context * ctx , int i_segment ) {
return ctx - > default_ state- > result_all [ i_segment ] . t1 ;
return ctx - > state- > result_all [ i_segment ] . t1 ;
}
const char * whisper_full_get_segment_text_from_state ( struct whisper_state * state , int i_segment ) {
@ -4484,7 +4501,7 @@ const char * whisper_full_get_segment_text_from_state(struct whisper_state * sta
}
const char * whisper_full_get_segment_text ( struct whisper_context * ctx , int i_segment ) {
return ctx - > default_ state- > result_all [ i_segment ] . text . c_str ( ) ;
return ctx - > state- > result_all [ i_segment ] . text . c_str ( ) ;
}
int whisper_full_n_tokens_from_state ( struct whisper_state * state , int i_segment ) {
@ -4492,7 +4509,7 @@ int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment
}
int whisper_full_n_tokens ( struct whisper_context * ctx , int i_segment ) {
return ctx - > default_ state- > result_all [ i_segment ] . tokens . size ( ) ;
return ctx - > state- > result_all [ i_segment ] . tokens . size ( ) ;
}
const char * whisper_full_get_token_text_from_state ( struct whisper_context * ctx , struct whisper_state * state , int i_segment , int i_token ) {
@ -4500,7 +4517,7 @@ const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx
}
const char * whisper_full_get_token_text ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > vocab . id_to_token [ ctx - > default_ state- > result_all [ i_segment ] . tokens [ i_token ] . id ] . c_str ( ) ;
return ctx - > vocab . id_to_token [ ctx - > state- > result_all [ i_segment ] . tokens [ i_token ] . id ] . c_str ( ) ;
}
whisper_token whisper_full_get_token_id_from_state ( struct whisper_state * state , int i_segment , int i_token ) {
@ -4508,7 +4525,7 @@ whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state,
}
whisper_token whisper_full_get_token_id ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > default_ state- > result_all [ i_segment ] . tokens [ i_token ] . id ;
return ctx - > state- > result_all [ i_segment ] . tokens [ i_token ] . id ;
}
struct whisper_token_data whisper_full_get_token_data_from_state ( struct whisper_state * state , int i_segment , int i_token ) {
@ -4516,7 +4533,7 @@ struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_
}
struct whisper_token_data whisper_full_get_token_data ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > default_ state- > result_all [ i_segment ] . tokens [ i_token ] ;
return ctx - > state- > result_all [ i_segment ] . tokens [ i_token ] ;
}
float whisper_full_get_token_p_from_state ( struct whisper_context * ctx , struct whisper_state * state , int i_segment , int i_token ) {
@ -4524,7 +4541,7 @@ float whisper_full_get_token_p_from_state(struct whisper_context * ctx, struct w
}
float whisper_full_get_token_p ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > default_ state- > result_all [ i_segment ] . tokens [ i_token ] . p ;
return ctx - > state- > result_all [ i_segment ] . tokens [ i_token ] . p ;
}
// =================================================================================================