@ -1,3 +1,4 @@
# define WHISPER_BUILD
# include "whisper.h"
# include "ggml.h"
@ -132,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024 * 1024 ;
static const std : : map < e_model , size_t > MEM_REQ_MODEL = {
{ MODEL_TINY , 86ull * MB } ,
{ MODEL_BASE , 165ull * MB } ,
{ MODEL_SMALL , 540ull * MB } ,
{ MODEL_MEDIUM , 1650ull * MB } ,
{ MODEL_LARGE , 3260ull * MB } ,
{ MODEL_TINY , 74ull * MB } ,
{ MODEL_BASE , 142ull * MB } ,
{ MODEL_SMALL , 466ull * MB } ,
{ MODEL_MEDIUM , 1464ull * MB } ,
{ MODEL_LARGE , 2952ull * MB } ,
} ;
static const std : : map < e_model , size_t > MEM_REQ_MEMORY = {
{ MODEL_TINY , 12ull * MB } ,
{ MODEL_BASE , 24ull * MB } ,
{ MODEL_SMALL , 70ull * MB } ,
{ MODEL_MEDIUM , 184ull * MB } ,
{ MODEL_LARGE , 306ull * MB } ,
} ;
static const std : : map < e_model , size_t > MEM_REQ_ENCODE = {
@ -210,14 +219,6 @@ struct whisper_vocab {
}
} ;
struct whisper_token_data {
whisper_token id ; // token id
whisper_token tid ; // forced timestamp token id
float p ; // probability of the token
float pt ; // probability of the timestamp token
} ;
struct whisper_segment {
int64_t t0 ;
int64_t t1 ;
@ -386,6 +387,7 @@ struct whisper_model {
// context
struct ggml_context * ctx ;
struct ggml_context * ctx_mem ;
// tensors
int n_loaded ;
@ -400,9 +402,10 @@ struct whisper_context {
int64_t t_decode_us = 0 ;
int64_t t_start_us = 0 ;
std : : vector < uint8_t > buf_model ;
std : : vector < uint8_t > buf_compute ;
std : : vector < uint8_t > buf_compute_layer ;
std : : vector < uint8_t > * buf_model ; // the model buffer is read-only and can be shared between processors
std : : vector < uint8_t > buf_memory ;
std : : vector < uint8_t > buf_compute ;
std : : vector < uint8_t > buf_compute_layer ;
whisper_model model ;
whisper_vocab vocab ;
@ -412,10 +415,15 @@ struct whisper_context {
std : : vector < float > probs ;
std : : vector < float > logits ;
std : : vector < whisper_token_data > tokens_cur ;
std : : vector < whisper_segment > result_all ;
std : : vector < whisper_token > prompt_past ;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg ;
int64_t t_last ;
whisper_token tid_last ;
std : : vector < float > energy ; // PCM signal energy
} ;
// load the model from a ggml file
@ -429,7 +437,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load ( const std : : string & fname , whisper_context & wctx ) {
static bool whisper_model_load ( const std : : string & fname , whisper_context & wctx ) {
fprintf ( stderr , " %s: loading model from '%s' \n " , __func__ , fname . c_str ( ) ) ;
auto & model = wctx . model ;
@ -502,13 +510,16 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf ( stderr , " %s: f16 = %d \n " , __func__ , hparams . f16 ) ;
fprintf ( stderr , " %s: type = %d \n " , __func__ , model . type ) ;
wctx . buf_model . resize ( MEM_REQ_MODEL . at ( model . type ) ) ;
wctx . buf_model = new std : : vector < uint8_t > ( ) ;
wctx . buf_model - > resize ( MEM_REQ_MODEL . at ( model . type ) ) ;
wctx . buf_memory . resize ( MEM_REQ_MEMORY . at ( model . type ) ) ;
wctx . buf_compute . resize ( std : : max ( MEM_REQ_ENCODE . at ( model . type ) , MEM_REQ_DECODE . at ( model . type ) ) ) ;
wctx . buf_compute_layer . resize ( std : : max ( MEM_REQ_ENCODE_LAYER . at ( model . type ) , MEM_REQ_DECODE_LAYER . at ( model . type ) ) ) ;
// this is the total memory required to run the inference
const size_t mem_required =
wctx . buf_model . size ( ) +
wctx . buf_model - > size ( ) +
wctx . buf_memory . size ( ) +
wctx . buf_compute . size ( ) +
wctx . buf_compute_layer . size ( ) ;
@ -591,6 +602,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
size_t ctx_size = 0 ;
size_t ctx_mem_size = 0 ;
{
const auto & hparams = model . hparams ;
@ -699,11 +711,11 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
ctx_size + = n_text_layer * ( n_text_state * ggml_type_size ( GGML_TYPE_F32 ) ) ; // cross_attn_ln_1_b
}
ctx_ size + = n_text_layer * n_text_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_k
ctx_ size + = n_text_layer * n_text_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_v
ctx_ mem_ size + = n_text_layer * n_text_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_k
ctx_ mem_ size + = n_text_layer * n_text_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_v
ctx_ size + = n_text_layer * n_audio_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_cross_k
ctx_ size + = n_text_layer * n_audio_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_cross_v
ctx_ mem_ size + = n_text_layer * n_audio_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_cross_k
ctx_ mem_ size + = n_text_layer * n_audio_ctx * n_text_state * ggml_type_size ( GGML_TYPE_F16 ) ; // memory_cross_v
ctx_size + = ( 15 + 15 * n_audio_layer + 24 * n_text_layer ) * 256 ; // object overhead
@ -713,8 +725,8 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// create the ggml context
{
struct ggml_init_params params = {
. mem_size = wctx . buf_model . size ( ) ,
. mem_buffer = wctx . buf_model . data ( ) ,
. mem_size = wctx . buf_model - > size ( ) ,
. mem_buffer = wctx . buf_model - > data ( ) ,
} ;
model . ctx = ggml_init ( params ) ;
@ -920,9 +932,23 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
. mem_size = wctx . buf_memory . size ( ) ,
. mem_buffer = wctx . buf_memory . data ( ) ,
} ;
model . ctx_mem = ggml_init ( params ) ;
if ( ! model . ctx_mem ) {
fprintf ( stderr , " %s: ggml_init() failed \n " , __func__ ) ;
return false ;
}
}
// key + value memory
{
auto & ctx = model . ctx ;
auto & ctx = model . ctx _mem ;
const auto & hparams = model . hparams ;
@ -954,7 +980,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
ggml_nbytes ( model . memory_k ) + ggml_nbytes ( model . memory_v ) +
ggml_nbytes ( model . memory_cross_k ) + ggml_nbytes ( model . memory_cross_v ) ;
fprintf ( stderr , " %s: memory size = %8.2f MB \n " , __func__ , memory_size / 1024.0 / 1024.0 ) ;
fprintf ( stderr , " %s: memory size = %8.2f MB \n " , __func__ , memory_size / 1024.0 / 1024.0 ) ;
}
// load weights
@ -1042,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
bool whisper_encode (
static bool whisper_encode (
whisper_context & wctx ,
const int n_threads ,
const int mel_offset ) {
@ -1428,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode (
static bool whisper_decode (
whisper_context & wctx ,
const int n_threads ,
const whisper_token * tokens ,
@ -1791,10 +1817,12 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best (
static whisper_token_data whisper_sample_best (
const whisper_vocab & vocab ,
const float * probs ) {
whisper_token_data result ;
whisper_token_data result = {
0 , 0 , 0.0f , 0.0f , 0.0f , - 1 , - 1 , 0.0f ,
} ;
int n_logits = vocab . id_to_token . size ( ) ;
@ -1831,7 +1859,8 @@ whisper_token_data whisper_sample_best(
}
}
result . pt = max_ts / ( sum_ts + 1e-6 ) ;
result . pt = max_ts / ( sum_ts + 1e-10 ) ;
result . ptsum = sum_ts ;
}
// find the top K tokens
@ -1866,7 +1895,7 @@ whisper_token_data whisper_sample_best(
}
// samples only from the timestamps tokens
whisper_vocab : : id whisper_sample_timestamp (
static whisper_vocab : : id whisper_sample_timestamp (
const whisper_vocab & vocab ,
const float * probs ) {
int n_logits = vocab . id_to_token . size ( ) ;
@ -1898,14 +1927,19 @@ whisper_vocab::id whisper_sample_timestamp(
return probs_id [ 0 ] . second ;
}
// convert a timestamp (given in units of 10 ms) to a "HH:MM:SS.mmm" string
// when comma == true, use "," as the millisecond separator (SRT subtitle style)
//   500 -> 00:00:05.000
//  6000 -> 00:01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
    int64_t msec = t * 10; // t is in centiseconds -> milliseconds

    int64_t hr = msec / (1000 * 60 * 60);
    msec = msec - hr * (1000 * 60 * 60);
    int64_t min = msec / (1000 * 60);
    msec = msec - min * (1000 * 60);
    int64_t sec = msec / 1000;
    msec = msec - sec * 1000;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);

    return std::string(buf);
}
@ -1913,7 +1947,7 @@ static std::string to_timestamp(int64_t t) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
void dft ( const std : : vector < float > & in , std : : vector < float > & out ) {
static void dft ( const std : : vector < float > & in , std : : vector < float > & out ) {
int N = in . size ( ) ;
out . resize ( N * 2 ) ;
@ -1937,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
void fft ( const std : : vector < float > & in , std : : vector < float > & out ) {
static void fft ( const std : : vector < float > & in , std : : vector < float > & out ) {
out . resize ( in . size ( ) * 2 ) ;
int N = in . size ( ) ;
@ -1988,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram (
static bool log_mel_spectrogram (
const float * samples ,
const int n_samples ,
const int sample_rate ,
@ -2127,6 +2161,9 @@ struct whisper_context * whisper_init(const char * path_model) {
void whisper_free ( struct whisper_context * ctx ) {
if ( ctx ) {
if ( ctx - > buf_model ) {
delete ctx - > buf_model ;
}
delete ctx ;
}
}
@ -2189,7 +2226,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return 0 ;
}
whisper_token whisper_sample_best ( struct whisper_context * ctx ) {
struct whisper_token_data whisper_sample_best ( struct whisper_context * ctx ) {
const int64_t t_start_sample_us = ggml_time_us ( ) ;
// TODO: simplify
@ -2197,7 +2234,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
ctx - > t_sample_us + = ggml_time_us ( ) - t_start_sample_us ;
return res .id ;
return res ;
}
whisper_token whisper_sample_timestamp ( struct whisper_context * ctx ) {
@ -2300,6 +2337,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.strategy =*/ WHISPER_SAMPLING_GREEDY ,
/*.n_threads =*/ std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ,
/*.n_max_text_ctx =*/ 16384 ,
/*.offset_ms =*/ 0 ,
/*.translate =*/ false ,
@ -2309,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false ,
/*.print_timestamps =*/ true ,
/*.token_timestamps =*/ false ,
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.language =*/ " en " ,
/*.greedy =*/ {
@ -2331,6 +2374,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH ,
/*.n_threads =*/ std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ,
/*.n_max_text_ctx =*/ 16384 ,
/*.offset_ms =*/ 0 ,
/*.translate =*/ false ,
@ -2340,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false ,
/*.print_timestamps =*/ true ,
/*.token_timestamps =*/ false ,
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.language =*/ " en " ,
/*.greedy =*/ {
@ -2361,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result ;
}
// forward declarations
static std : : vector < float > get_signal_energy ( const float * signal , int n_samples , int n_samples_per_half_window ) ;
static void whisper_exp_compute_token_level_timestamps (
struct whisper_context * ctx ,
int i_segment ,
float thold_pt ,
float thold_ptsum ) ;
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment ( struct whisper_context * ctx , int max_len ) {
auto segment = ctx - > result_all . back ( ) ;
int res = 1 ;
int acc = 0 ;
std : : string text ;
for ( int i = 0 ; i < ( int ) segment . tokens . size ( ) ; i + + ) {
const auto & token = segment . tokens [ i ] ;
if ( token . id > = whisper_token_eot ( ctx ) ) {
continue ;
}
const auto txt = whisper_token_to_str ( ctx , token . id ) ;
const int cur = strlen ( txt ) ;
if ( acc + cur > max_len & & i > 0 ) {
// split here
ctx - > result_all . back ( ) . text = std : : move ( text ) ;
ctx - > result_all . back ( ) . t1 = token . t0 ;
ctx - > result_all . back ( ) . tokens . resize ( i ) ;
ctx - > result_all . push_back ( { } ) ;
ctx - > result_all . back ( ) . t0 = token . t0 ;
ctx - > result_all . back ( ) . t1 = segment . t1 ;
// add tokens [i, end] to the new segment
ctx - > result_all . back ( ) . tokens . insert (
ctx - > result_all . back ( ) . tokens . end ( ) ,
segment . tokens . begin ( ) + i ,
segment . tokens . end ( ) ) ;
acc = 0 ;
text = " " ;
segment = ctx - > result_all . back ( ) ;
i = - 1 ;
res + + ;
} else {
acc + = cur ;
text + = txt ;
}
}
ctx - > result_all . back ( ) . text = std : : move ( text ) ;
return res ;
}
int whisper_full (
struct whisper_context * ctx ,
struct whisper_full_params params ,
@ -2368,7 +2479,6 @@ int whisper_full(
int n_samples ) {
// clear old results
auto & result_all = ctx - > result_all ;
auto & tokens_cur = ctx - > tokens_cur ;
result_all . clear ( ) ;
@ -2378,10 +2488,19 @@ int whisper_full(
return - 1 ;
}
if ( params . token_timestamps ) {
ctx - > t_beg = 0 ;
ctx - > t_last = 0 ;
ctx - > tid_last = 0 ;
ctx - > energy = get_signal_energy ( samples , n_samples , 32 ) ;
}
const int seek_start = params . offset_ms / 10 ;
// if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if ( whisper_n_len ( ctx ) < 100 ) {
if ( whisper_n_len ( ctx ) < 100 + seek_start ) {
return 0 ;
}
@ -2405,8 +2524,14 @@ int whisper_full(
int progress_prev = 0 ;
int progress_step = 5 ;
std : : vector < whisper_token_data > tokens_cur ;
tokens_cur . reserve ( whisper_n_text_ctx ( ctx ) ) ;
std : : vector < whisper_token > prompt ;
prompt . reserve ( whisper_n_text_ctx ( ctx ) ) ;
// main loop
int seek = params . offset_ms / 10 ;
int seek = seek_start ;
while ( true ) {
int progress_cur = ( 100 * seek ) / whisper_n_len ( ctx ) ;
while ( progress_cur > = progress_prev + progress_step ) {
@ -2426,13 +2551,12 @@ int whisper_full(
return 7 ;
}
std : : vector < whisper_token > prompt ;
int n_past = 0 ;
prompt . clear ( ) ;
// if we have already generated some text, use it as a prompt to condition the next generation
if ( prompt_past . size ( ) > 0 ) {
int n_take = std : : min ( whisper_n_text_ctx( ctx ) / 2 , int ( prompt_past . size ( ) ) ) ;
int n_take = std : : min ( std: : min ( params . n_max_text_ctx , whisper_n_text_ctx( ctx ) / 2 ) , int ( prompt_past . size ( ) ) ) ;
prompt = { whisper_token_prev ( ctx ) } ;
prompt . insert ( prompt . begin ( ) + 1 , prompt_past . end ( ) - n_take , prompt_past . end ( ) ) ;
@ -2474,7 +2598,7 @@ int whisper_full(
// feel free to experiment!
//
{
auto token = whisper_sample_best ( ctx - > vocab , ctx - > probs . data ( ) + ( ctx - > probs . size ( ) - ctx - > vocab . n_vocab ) ) ;
auto token = whisper_sample_best ( ctx ) ;
if ( i = = 0 ) {
token . tid = whisper_token_beg ( ctx ) ;
@ -2490,7 +2614,10 @@ int whisper_full(
prompt . push_back ( token . id ) ;
tokens_cur . push_back ( token ) ;
//printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
if ( token . id = = whisper_token_eot ( ctx ) ) {
@ -2517,6 +2644,7 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur . resize ( result_len ) ;
for ( const auto & r : tokens_cur ) {
@ -2555,8 +2683,19 @@ int whisper_full(
for ( int j = i0 ; j < = i ; j + + ) {
result_all . back ( ) . tokens . push_back ( tokens_cur [ j ] ) ;
}
int n_new = 1 ;
if ( params . token_timestamps ) {
whisper_exp_compute_token_level_timestamps (
ctx , result_all . size ( ) - 1 , params . thold_pt , params . thold_ptsum ) ;
if ( params . max_len > 0 ) {
n_new = whisper_wrap_segment ( ctx , params . max_len ) ;
}
}
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , n_new, params. new_segment_callback_user_data ) ;
}
}
text = " " ;
@ -2585,8 +2724,19 @@ int whisper_full(
for ( int j = i0 ; j < ( int ) tokens_cur . size ( ) ; j + + ) {
result_all . back ( ) . tokens . push_back ( tokens_cur [ j ] ) ;
}
int n_new = 1 ;
if ( params . token_timestamps ) {
whisper_exp_compute_token_level_timestamps (
ctx , result_all . size ( ) - 1 , params . thold_pt , params . thold_ptsum ) ;
if ( params . max_len > 0 ) {
n_new = whisper_wrap_segment ( ctx , params . max_len ) ;
}
}
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , n_new, params. new_segment_callback_user_data ) ;
}
}
}
@ -2597,6 +2747,156 @@ int whisper_full(
return 0 ;
}
int whisper_full_parallel (
struct whisper_context * ctx ,
struct whisper_full_params params ,
const float * samples ,
int n_samples ,
const int n_processors ) {
if ( n_processors = = 1 ) {
return whisper_full ( ctx , params , samples , n_samples ) ;
}
int ret = 0 ;
// prepare separate contexts for each thread
std : : vector < struct whisper_context > ctxs ( n_processors - 1 ) ;
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
ctxs [ i ] = * ctx ;
auto & model = ctxs [ i ] . model ;
// create the ggml memory context
{
struct ggml_init_params params = {
. mem_size = ctxs [ i ] . buf_memory . size ( ) ,
. mem_buffer = ctxs [ i ] . buf_memory . data ( ) ,
} ;
model . ctx_mem = ggml_init ( params ) ;
if ( ! model . ctx_mem ) {
fprintf ( stderr , " %s: ggml_init() failed \n " , __func__ ) ;
return false ;
}
}
// separate key + value memory for each processor
{
auto & ctx = model . ctx_mem ;
const auto & hparams = model . hparams ;
const int n_text_state = hparams . n_text_state ;
const int n_text_layer = hparams . n_text_layer ;
const int n_text_ctx = hparams . n_text_ctx ;
// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer * n_text_ctx ;
const int n_elements = n_text_state * n_mem ;
model . memory_k = ggml_new_tensor_1d ( ctx , GGML_TYPE_F16 , n_elements ) ;
model . memory_v = ggml_new_tensor_1d ( ctx , GGML_TYPE_F16 , n_elements ) ;
}
// key/value memory for the cross-attention layer
{
const int n_audio_ctx = hparams . n_audio_ctx ;
const int n_mem = n_text_layer * n_audio_ctx ;
const int n_elements = n_text_state * n_mem ;
model . memory_cross_k = ggml_new_tensor_1d ( ctx , GGML_TYPE_F16 , n_elements ) ;
model . memory_cross_v = ggml_new_tensor_1d ( ctx , GGML_TYPE_F16 , n_elements ) ;
}
const size_t memory_size =
ggml_nbytes ( model . memory_k ) + ggml_nbytes ( model . memory_v ) +
ggml_nbytes ( model . memory_cross_k ) + ggml_nbytes ( model . memory_cross_v ) ;
}
}
const int offset_samples = ( WHISPER_SAMPLE_RATE * params . offset_ms ) / 1000 ;
const int n_samples_per_processor = ( n_samples - offset_samples ) / n_processors ;
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std : : vector < std : : thread > workers ( n_processors - 1 ) ;
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
const int start_samples = offset_samples + ( i + 1 ) * n_samples_per_processor ;
const int n_samples_cur = ( i = = n_processors - 2 ) ? n_samples - start_samples : n_samples_per_processor ;
auto params_cur = params ;
params_cur . offset_ms = 0 ;
params_cur . print_progress = false ;
params_cur . print_realtime = false ;
params_cur . new_segment_callback = nullptr ;
params_cur . new_segment_callback_user_data = nullptr ;
workers [ i ] = std : : thread ( whisper_full , & ctxs [ i ] , std : : move ( params_cur ) , samples + start_samples , n_samples_cur ) ;
}
{
auto params_cur = params ;
ret = whisper_full ( ctx , std : : move ( params_cur ) , samples , offset_samples + n_samples_per_processor ) ;
}
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
workers [ i ] . join ( ) ;
}
const int64_t offset_t = ( int64_t ) params . offset_ms / 10.0 ;
// combine results into ctx->result_all
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
auto & results_i = ctxs [ i ] . result_all ;
for ( int j = 0 ; j < ( int ) results_i . size ( ) ; + + j ) {
// correct the segment timestamp taking into account the offset
results_i [ j ] . t0 + = 100 * ( ( i + 1 ) * n_samples_per_processor ) / WHISPER_SAMPLE_RATE + offset_t ;
results_i [ j ] . t1 + = 100 * ( ( i + 1 ) * n_samples_per_processor ) / WHISPER_SAMPLE_RATE + offset_t ;
// make sure that segments are not overlapping
if ( ctx - > result_all . size ( ) > 0 ) {
results_i [ j ] . t0 = std : : max ( results_i [ j ] . t0 , ctx - > result_all . back ( ) . t1 ) ;
}
ctx - > result_all . push_back ( std : : move ( results_i [ j ] ) ) ;
// call the new_segment_callback for each segment
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , 1 , params . new_segment_callback_user_data ) ;
}
}
ctx - > t_mel_us + = ctxs [ i ] . t_mel_us ;
ctx - > t_sample_us + = ctxs [ i ] . t_sample_us ;
ctx - > t_encode_us + = ctxs [ i ] . t_encode_us ;
ctx - > t_decode_us + = ctxs [ i ] . t_decode_us ;
}
// average the timings
ctx - > t_mel_us / = n_processors ;
ctx - > t_sample_us / = n_processors ;
ctx - > t_encode_us / = n_processors ;
ctx - > t_decode_us / = n_processors ;
// print information about the audio boundaries
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " %s: the audio has been split into %d chunks at the following times: \n " , __func__ , n_processors ) ;
for ( int i = 0 ; i < n_processors - 1 ; + + i ) {
fprintf ( stderr , " %s: split %d - %s \n " , __func__ , ( i + 1 ) , to_timestamp ( 100 * ( ( i + 1 ) * n_samples_per_processor ) / WHISPER_SAMPLE_RATE + offset_t ) . c_str ( ) ) ;
}
fprintf ( stderr , " %s: the transcription quality may be degraded near these boundaries \n " , __func__ ) ;
return ret ;
}
int whisper_full_n_segments ( struct whisper_context * ctx ) {
return ctx - > result_all . size ( ) ;
}
@ -2625,6 +2925,10 @@ whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segm
return ctx - > result_all [ i_segment ] . tokens [ i_token ] . id ;
}
struct whisper_token_data whisper_full_get_token_data ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > result_all [ i_segment ] . tokens [ i_token ] ;
}
float whisper_full_get_token_p ( struct whisper_context * ctx , int i_segment , int i_token ) {
return ctx - > result_all [ i_segment ] . tokens [ i_token ] . p ;
}
@ -2642,3 +2946,304 @@ const char * whisper_print_system_info() {
return s . c_str ( ) ;
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample ( int64_t t , int n_samples ) {
return std : : max ( 0 , std : : min ( ( int ) n_samples - 1 , ( int ) ( ( t * WHISPER_SAMPLE_RATE ) / 100 ) ) ) ;
}
// convert a sample index back to a timestamp in 10 ms units
static int64_t sample_to_timestamp(int i_sample) {
    const int t = (100*i_sample)/WHISPER_SAMPLE_RATE;
    return t;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length(const std::string & text) {
    float total = 0.0f;

    for (const char c : text) {
        switch (c) {
            case ' ':
                total += 0.01f;
                break;
            case ',':
                total += 2.00f;
                break;
            case '.':
            case '!':
            case '?':
                total += 3.00f;
                break;
            default:
                // digits are "slow" like punctuation; everything else costs 1
                total += (c >= '0' && c <= '9') ? 3.00f : 1.00f;
                break;
        }
    }

    return total;
}
// average the fabs of the signal
// for each sample, the mean absolute amplitude over a window of
// (2*n_samples_per_half_window + 1) samples centered on it; the divisor is the
// full window size even when the window is clipped at the signal edges
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
    const int half = n_samples_per_half_window;

    std::vector<float> energy(n_samples);

    for (int i = 0; i < n_samples; i++) {
        const int lo = std::max(0, i - half);
        const int hi = std::min(n_samples - 1, i + half);

        float acc = 0.0f;
        for (int k = lo; k <= hi; k++) {
            acc += fabs(signal[k]);
        }

        energy[i] = acc/(2*half + 1);
    }

    return energy;
}
// [EXPERIMENTAL] assign token-level t0/t1 timestamps to segment i_segment
// heuristic pipeline: (1) anchor tokens on decoded timestamp tokens whose
// probabilities pass the thresholds, (2) proportionally distribute the
// remaining time by voice length, (3) refine boundaries with an energy-based
// VAD pass; maintains running state in ctx->t_beg / t_last / tid_last so that
// consecutive segments stay monotonic
static void whisper_exp_compute_token_level_timestamps(
        struct whisper_context * ctx,
        int i_segment,
        float thold_pt,     // min timestamp-token probability to accept an anchor
        float thold_ptsum) {// min summed timestamp probability to accept an anchor
    auto & segment = ctx->result_all[i_segment];
    auto & tokens  = segment.tokens;

    // energy must have been precomputed by whisper_full (params.token_timestamps)
    const int n_samples = ctx->energy.size();

    if (n_samples == 0) {
        fprintf(stderr, "%s: no signal data available\n", __func__);
        return;
    }

    const int64_t t0 = segment.t0;
    const int64_t t1 = segment.t1;

    // NOTE(review): these outer s0/s1 are never read - the VAD pass below
    // shadows them with per-token values; presumably leftovers
    const int s0 = timestamp_to_sample(t0, n_samples);
    const int s1 = timestamp_to_sample(t1, n_samples);

    const int n = tokens.size();

    if (n == 0) {
        return;
    }

    // a single token simply spans the whole segment
    if (n == 1) {
        tokens[0].t0 = t0;
        tokens[0].t1 = t1;

        return;
    }

    // running state shared across segments
    auto & t_beg    = ctx->t_beg;
    auto & t_last   = ctx->t_last;
    auto & tid_last = ctx->tid_last;

    // pass 1: anchor tokens on confident timestamp tokens
    for (int j = 0; j < n; ++j) {
        auto & token = tokens[j];

        if (j == 0) {
            if (token.id == whisper_token_beg(ctx)) {
                // segment starts with the "begin" timestamp token - reset state
                // (n >= 2 here, so tokens[j + 1] is valid)
                tokens[j    ].t0 = t0;
                tokens[j    ].t1 = t0;
                tokens[j + 1].t0 = t0;

                t_beg    = t0;
                t_last   = t0;
                tid_last = whisper_token_beg(ctx);
            } else {
                // continue from where the previous segment left off
                tokens[j    ].t0 = t_last;
            }
        }

        // time encoded by this token's timestamp id - timestamp tokens are
        // spaced 20 ms apart, i.e. 2 units of 10 ms
        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));

        // self-assignments (token aliases tokens[j]); only vlen is newly computed
        tokens[j].id    = token.id;
        tokens[j].tid   = token.tid;
        tokens[j].p     = token.p;
        tokens[j].pt    = token.pt;
        tokens[j].ptsum = token.ptsum;

        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));

        // accept the timestamp as an anchor only if it is confident, advances
        // monotonically and stays within the segment
        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
            if (j > 0) {
                tokens[j - 1].t1 = tt;
            }
            tokens[j].t0 = tt;
            tid_last = token.tid;
        }
    }

    // pin the segment tail to t1
    tokens[n - 2].t1 = t1;
    tokens[n - 1].t0 = t1;
    tokens[n - 1].t1 = t1;

    t_last = t1;

    // find intervals of tokens with unknown timestamps
    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
    {
        int p0 = 0;
        int p1 = 0;

        while (true) {
            // advance p1 to the next token with a known end time
            // (unset t1 is negative - presumably initialized to -1; verify
            // against whisper_token_data's default)
            while (p1 < n && tokens[p1].t1 < 0) {
                p1++;
            }

            if (p1 >= n) {
                p1--;
            }

            if (p1 > p0) {
                double psum = 0.0;
                for (int j = p0; j <= p1; j++) {
                    psum += tokens[j].vlen;
                }

                //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);

                const double dt = tokens[p1].t1 - tokens[p0].t0;

                // split the time proportionally to the voice length
                for (int j = p0 + 1; j <= p1; j++) {
                    const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;

                    tokens[j - 1].t1 = ct;
                    tokens[j    ].t0 = ct;
                }
            }

            p1++;
            p0 = p1;
            if (p1 >= n) {
                break;
            }
        }
    }

    // fix up (just in case)
    for (int j = 0; j < n - 1; j++) {
        // NOTE(review): this propagates a *negative* t1 into the next token's
        // t0 - looks suspicious (copying an unset sentinel forward); confirm
        // intent before changing
        if (tokens[j].t1 < 0) {
            tokens[j + 1].t0 = tokens[j].t1;
        }

        // enforce monotonicity: a token cannot start before the previous ends
        if (j > 0) {
            if (tokens[j - 1].t1 > tokens[j].t0) {
                tokens[j].t0 = tokens[j - 1].t1;
                tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
            }
        }
    }

    // VAD
    // expand or contract tokens based on voice activity
    {
        // half-window of 125 ms worth of samples around each token boundary
        const int hw = WHISPER_SAMPLE_RATE/8;

        for (int j = 0; j < n; j++) {
            // skip special tokens
            if (tokens[j].id >= whisper_token_eot(ctx)) {
                continue;
            }

            int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
            int s1 = timestamp_to_sample(tokens[j].t1, n_samples);

            const int ss0 = std::max(s0 - hw, 0);
            const int ss1 = std::min(s1 + hw, n_samples);

            const int ns = ss1 - ss0;

            float sum = 0.0f;

            for (int k = ss0; k < ss1; k++) {
                sum += ctx->energy[k];
            }

            // voice-activity threshold: half the mean energy of the window
            const float thold = 0.5*sum/ns;

            // adjust the token start boundary
            {
                int k = s0;
                if (ctx->energy[k] > thold && j > 0) {
                    // inside voiced audio - expand left to the start of the voiced region
                    while (k > 0 && ctx->energy[k] > thold) {
                        k--;
                    }
                    tokens[j].t0 = sample_to_timestamp(k);
                    if (tokens[j].t0 < tokens[j - 1].t1) {
                        // do not overlap the previous token
                        tokens[j].t0 = tokens[j - 1].t1;
                    } else {
                        s0 = k;
                    }
                } else {
                    // inside silence - contract right to the first voiced sample
                    while (ctx->energy[k] < thold && k < s1) {
                        k++;
                    }
                    s0 = k;
                    tokens[j].t0 = sample_to_timestamp(k);
                }
            }

            // adjust the token end boundary
            {
                int k = s1;
                if (ctx->energy[k] > thold) {
                    // inside voiced audio - expand right to the end of the voiced region
                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
                        k++;
                    }
                    tokens[j].t1 = sample_to_timestamp(k);
                    // NOTE(review): "j < ns - 1" compares the token index
                    // against the energy-window sample count - presumably
                    // "n - 1" was intended; kept as-is, confirm upstream
                    if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                        tokens[j].t1 = tokens[j + 1].t0;
                    } else {
                        s1 = k;
                    }
                } else {
                    // inside silence - contract left to the last voiced sample
                    while (ctx->energy[k] < thold && k > s0) {
                        k--;
                    }
                    s1 = k;
                    tokens[j].t1 = sample_to_timestamp(k);
                }
            }
        }
    }

    // fixed token expand (optional)
    //{
    //    const int t_expand = 0;

    //    for (int j = 0; j < n; j++) {
    //        if (j > 0) {
    //            tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
    //        }
    //        if (j < n - 1) {
    //            tokens[j].t1 = tokens[j].t1 + t_expand;
    //        }
    //    }
    //}

    // debug info
    //for (int j = 0; j < n; ++j) {
    //    const auto & token = tokens[j];
    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));

    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
    //        continue;
    //    }
    //}
}