@ -418,6 +418,12 @@ struct whisper_context {
std : : vector < whisper_segment > result_all ;
std : : vector < whisper_token > prompt_past ;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg ;
int64_t t_last ;
whisper_token tid_last ;
std : : vector < float > energy ; // PCM signal energy
} ;
// load the model from a ggml file
@ -431,7 +437,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load ( const std : : string & fname , whisper_context & wctx ) {
static bool whisper_model_load ( const std : : string & fname , whisper_context & wctx ) {
fprintf ( stderr , " %s: loading model from '%s' \n " , __func__ , fname . c_str ( ) ) ;
auto & model = wctx . model ;
@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
bool whisper_encode (
static bool whisper_encode (
whisper_context & wctx ,
const int n_threads ,
const int mel_offset ) {
@ -1448,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode (
static bool whisper_decode (
whisper_context & wctx ,
const int n_threads ,
const whisper_token * tokens ,
@ -1811,10 +1817,12 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best (
static whisper_token_data whisper_sample_best (
const whisper_vocab & vocab ,
const float * probs ) {
whisper_token_data result ;
whisper_token_data result = {
0 , 0 , 0.0f , 0.0f , 0.0f , - 1 , - 1 , 0.0f ,
} ;
int n_logits = vocab . id_to_token . size ( ) ;
@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best(
}
// samples only from the timestamps tokens
whisper_vocab : : id whisper_sample_timestamp (
static whisper_vocab : : id whisper_sample_timestamp (
const whisper_vocab & vocab ,
const float * probs ) {
int n_logits = vocab . id_to_token . size ( ) ;
@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
void dft ( const std : : vector < float > & in , std : : vector < float > & out ) {
static void dft ( const std : : vector < float > & in , std : : vector < float > & out ) {
int N = in . size ( ) ;
out . resize ( N * 2 ) ;
@ -1963,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
void fft ( const std : : vector < float > & in , std : : vector < float > & out ) {
static void fft ( const std : : vector < float > & in , std : : vector < float > & out ) {
out . resize ( in . size ( ) * 2 ) ;
int N = in . size ( ) ;
@ -2014,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram (
static bool log_mel_spectrogram (
const float * samples ,
const int n_samples ,
const int sample_rate ,
@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false ,
/*.print_timestamps =*/ true ,
/*.token_timestamps =*/ false ,
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.language =*/ " en " ,
/*.greedy =*/ {
@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false ,
/*.print_timestamps =*/ true ,
/*.token_timestamps =*/ false ,
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.language =*/ " en " ,
/*.greedy =*/ {
@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result ;
}
// forward declarations
static std : : vector < float > get_signal_energy ( const float * signal , int n_samples , int n_samples_per_half_window ) ;
static void whisper_exp_compute_token_level_timestamps (
struct whisper_context * ctx ,
int i_segment ,
float thold_pt ,
float thold_ptsum ) ;
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment ( struct whisper_context * ctx , int max_len ) {
auto segment = ctx - > result_all . back ( ) ;
int res = 1 ;
int acc = 0 ;
std : : string text ;
for ( int i = 0 ; i < ( int ) segment . tokens . size ( ) ; i + + ) {
const auto & token = segment . tokens [ i ] ;
if ( token . id > = whisper_token_eot ( ctx ) ) {
continue ;
}
const auto txt = whisper_token_to_str ( ctx , token . id ) ;
const int cur = strlen ( txt ) ;
if ( acc + cur > max_len & & i > 0 ) {
// split here
ctx - > result_all . back ( ) . text = std : : move ( text ) ;
ctx - > result_all . back ( ) . t1 = token . t0 ;
ctx - > result_all . back ( ) . tokens . resize ( i ) ;
ctx - > result_all . push_back ( { } ) ;
ctx - > result_all . back ( ) . t0 = token . t0 ;
ctx - > result_all . back ( ) . t1 = segment . t1 ;
// add tokens [i, end] to the new segment
ctx - > result_all . back ( ) . tokens . insert (
ctx - > result_all . back ( ) . tokens . end ( ) ,
segment . tokens . begin ( ) + i ,
segment . tokens . end ( ) ) ;
acc = 0 ;
text = " " ;
segment = ctx - > result_all . back ( ) ;
i = - 1 ;
res + + ;
} else {
acc + = cur ;
text + = txt ;
}
}
ctx - > result_all . back ( ) . text = std : : move ( text ) ;
return res ;
}
int whisper_full (
struct whisper_context * ctx ,
struct whisper_full_params params ,
@ -2408,6 +2488,13 @@ int whisper_full(
return - 1 ;
}
if ( params . token_timestamps ) {
ctx - > t_beg = 0 ;
ctx - > t_last = 0 ;
ctx - > tid_last = 0 ;
ctx - > energy = get_signal_energy ( samples , n_samples , 32 ) ;
}
const int seek_start = params . offset_ms / 10 ;
// if length of spectrogram is less than 1s (100 samples), then return
@ -2557,6 +2644,7 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur . resize ( result_len ) ;
for ( const auto & r : tokens_cur ) {
@ -2595,8 +2683,19 @@ int whisper_full(
for ( int j = i0 ; j < = i ; j + + ) {
result_all . back ( ) . tokens . push_back ( tokens_cur [ j ] ) ;
}
int n_new = 1 ;
if ( params . token_timestamps ) {
whisper_exp_compute_token_level_timestamps (
ctx , result_all . size ( ) - 1 , params . thold_pt , params . thold_ptsum ) ;
if ( params . max_len > 0 ) {
n_new = whisper_wrap_segment ( ctx , params . max_len ) ;
}
}
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , n_new, params. new_segment_callback_user_data ) ;
}
}
text = " " ;
@ -2625,8 +2724,19 @@ int whisper_full(
for ( int j = i0 ; j < ( int ) tokens_cur . size ( ) ; j + + ) {
result_all . back ( ) . tokens . push_back ( tokens_cur [ j ] ) ;
}
int n_new = 1 ;
if ( params . token_timestamps ) {
whisper_exp_compute_token_level_timestamps (
ctx , result_all . size ( ) - 1 , params . thold_pt , params . thold_ptsum ) ;
if ( params . max_len > 0 ) {
n_new = whisper_wrap_segment ( ctx , params . max_len ) ;
}
}
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , n_new, params. new_segment_callback_user_data ) ;
}
}
}
@ -2760,7 +2870,7 @@ int whisper_full_parallel(
// call the new_segment_callback for each segment
if ( params . new_segment_callback ) {
params . new_segment_callback ( ctx , params . new_segment_callback_user_data ) ;
params . new_segment_callback ( ctx , 1 , params . new_segment_callback_user_data ) ;
}
}
@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() {
return s . c_str ( ) ;
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample ( int64_t t , int n_samples ) {
return std : : max ( 0 , std : : min ( ( int ) n_samples - 1 , ( int ) ( ( t * WHISPER_SAMPLE_RATE ) / 100 ) ) ) ;
}
static int64_t sample_to_timestamp ( int i_sample ) {
return ( 100 * i_sample ) / WHISPER_SAMPLE_RATE ;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length ( const std : : string & text ) {
float res = 0.0f ;
for ( size_t i = 0 ; i < text . size ( ) ; + + i ) {
if ( text [ i ] = = ' ' ) {
res + = 0.01f ;
} else if ( text [ i ] = = ' , ' ) {
res + = 2.00f ;
} else if ( text [ i ] = = ' . ' ) {
res + = 3.00f ;
} else if ( text [ i ] = = ' ! ' ) {
res + = 3.00f ;
} else if ( text [ i ] = = ' ? ' ) {
res + = 3.00f ;
} else if ( text [ i ] > = ' 0 ' & & text [ i ] < = ' 9 ' ) {
res + = 3.00f ;
} else {
res + = 1.00f ;
}
}
return res ;
}
// average the fabs of the signal
static std : : vector < float > get_signal_energy ( const float * signal , int n_samples , int n_samples_per_half_window ) {
const int hw = n_samples_per_half_window ;
std : : vector < float > result ( n_samples ) ;
for ( int i = 0 ; i < n_samples ; i + + ) {
float sum = 0 ;
for ( int j = - hw ; j < = hw ; j + + ) {
if ( i + j > = 0 & & i + j < n_samples ) {
sum + = fabs ( signal [ i + j ] ) ;
}
}
result [ i ] = sum / ( 2 * hw + 1 ) ;
}
return result ;
}
static void whisper_exp_compute_token_level_timestamps (
struct whisper_context * ctx ,
int i_segment ,
float thold_pt ,
float thold_ptsum ) {
auto & segment = ctx - > result_all [ i_segment ] ;
auto & tokens = segment . tokens ;
const int n_samples = ctx - > energy . size ( ) ;
if ( n_samples = = 0 ) {
fprintf ( stderr , " %s: no signal data available \n " , __func__ ) ;
return ;
}
const int64_t t0 = segment . t0 ;
const int64_t t1 = segment . t1 ;
const int s0 = timestamp_to_sample ( t0 , n_samples ) ;
const int s1 = timestamp_to_sample ( t1 , n_samples ) ;
const int n = tokens . size ( ) ;
if ( n = = 0 ) {
return ;
}
if ( n = = 1 ) {
tokens [ 0 ] . t0 = t0 ;
tokens [ 0 ] . t1 = t1 ;
return ;
}
auto & t_beg = ctx - > t_beg ;
auto & t_last = ctx - > t_last ;
auto & tid_last = ctx - > tid_last ;
for ( int j = 0 ; j < n ; + + j ) {
auto & token = tokens [ j ] ;
if ( j = = 0 ) {
if ( token . id = = whisper_token_beg ( ctx ) ) {
tokens [ j ] . t0 = t0 ;
tokens [ j ] . t1 = t0 ;
tokens [ j + 1 ] . t0 = t0 ;
t_beg = t0 ;
t_last = t0 ;
tid_last = whisper_token_beg ( ctx ) ;
} else {
tokens [ j ] . t0 = t_last ;
}
}
const int64_t tt = t_beg + 2 * ( token . tid - whisper_token_beg ( ctx ) ) ;
tokens [ j ] . id = token . id ;
tokens [ j ] . tid = token . tid ;
tokens [ j ] . p = token . p ;
tokens [ j ] . pt = token . pt ;
tokens [ j ] . ptsum = token . ptsum ;
tokens [ j ] . vlen = voice_length ( whisper_token_to_str ( ctx , token . id ) ) ;
if ( token . pt > thold_pt & & token . ptsum > thold_ptsum & & token . tid > tid_last & & tt < = t1 ) {
if ( j > 0 ) {
tokens [ j - 1 ] . t1 = tt ;
}
tokens [ j ] . t0 = tt ;
tid_last = token . tid ;
}
}
tokens [ n - 2 ] . t1 = t1 ;
tokens [ n - 1 ] . t0 = t1 ;
tokens [ n - 1 ] . t1 = t1 ;
t_last = t1 ;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0 ;
int p1 = 0 ;
while ( true ) {
while ( p1 < n & & tokens [ p1 ] . t1 < 0 ) {
p1 + + ;
}
if ( p1 > = n ) {
p1 - - ;
}
if ( p1 > p0 ) {
double psum = 0.0 ;
for ( int j = p0 ; j < = p1 ; j + + ) {
psum + = tokens [ j ] . vlen ;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens [ p1 ] . t1 - tokens [ p0 ] . t0 ;
// split the time proportionally to the voice length
for ( int j = p0 + 1 ; j < = p1 ; j + + ) {
const double ct = tokens [ j - 1 ] . t0 + dt * tokens [ j - 1 ] . vlen / psum ;
tokens [ j - 1 ] . t1 = ct ;
tokens [ j ] . t0 = ct ;
}
}
p1 + + ;
p0 = p1 ;
if ( p1 > = n ) {
break ;
}
}
}
// fix up (just in case)
for ( int j = 0 ; j < n - 1 ; j + + ) {
if ( tokens [ j ] . t1 < 0 ) {
tokens [ j + 1 ] . t0 = tokens [ j ] . t1 ;
}
if ( j > 0 ) {
if ( tokens [ j - 1 ] . t1 > tokens [ j ] . t0 ) {
tokens [ j ] . t0 = tokens [ j - 1 ] . t1 ;
tokens [ j ] . t1 = std : : max ( tokens [ j ] . t0 , tokens [ j ] . t1 ) ;
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
const int hw = WHISPER_SAMPLE_RATE / 8 ;
for ( int j = 0 ; j < n ; j + + ) {
if ( tokens [ j ] . id > = whisper_token_eot ( ctx ) ) {
continue ;
}
int s0 = timestamp_to_sample ( tokens [ j ] . t0 , n_samples ) ;
int s1 = timestamp_to_sample ( tokens [ j ] . t1 , n_samples ) ;
const int ss0 = std : : max ( s0 - hw , 0 ) ;
const int ss1 = std : : min ( s1 + hw , n_samples ) ;
const int ns = ss1 - ss0 ;
float sum = 0.0f ;
for ( int k = ss0 ; k < ss1 ; k + + ) {
sum + = ctx - > energy [ k ] ;
}
const float thold = 0.5 * sum / ns ;
{
int k = s0 ;
if ( ctx - > energy [ k ] > thold & & j > 0 ) {
while ( k > 0 & & ctx - > energy [ k ] > thold ) {
k - - ;
}
tokens [ j ] . t0 = sample_to_timestamp ( k ) ;
if ( tokens [ j ] . t0 < tokens [ j - 1 ] . t1 ) {
tokens [ j ] . t0 = tokens [ j - 1 ] . t1 ;
} else {
s0 = k ;
}
} else {
while ( ctx - > energy [ k ] < thold & & k < s1 ) {
k + + ;
}
s0 = k ;
tokens [ j ] . t0 = sample_to_timestamp ( k ) ;
}
}
{
int k = s1 ;
if ( ctx - > energy [ k ] > thold ) {
while ( k < n_samples - 1 & & ctx - > energy [ k ] > thold ) {
k + + ;
}
tokens [ j ] . t1 = sample_to_timestamp ( k ) ;
if ( j < ns - 1 & & tokens [ j ] . t1 > tokens [ j + 1 ] . t0 ) {
tokens [ j ] . t1 = tokens [ j + 1 ] . t0 ;
} else {
s1 = k ;
}
} else {
while ( ctx - > energy [ k ] < thold & & k > s0 ) {
k - - ;
}
s1 = k ;
tokens [ j ] . t1 = sample_to_timestamp ( k ) ;
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}