@ -5,6 +5,7 @@
# define DR_WAV_IMPLEMENTATION
# define DR_WAV_IMPLEMENTATION
# include "dr_wav.h"
# include "dr_wav.h"
# include <fstream>
# include <cstdio>
# include <cstdio>
# include <string>
# include <string>
# include <thread>
# include <thread>
@ -28,15 +29,20 @@ std::string to_timestamp(int64_t t) {
struct whisper_params {
struct whisper_params {
int32_t seed = - 1 ; // RNG seed, not used currently
int32_t seed = - 1 ; // RNG seed, not used currently
int32_t n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ;
int32_t n_threads = std : : min ( 4 , ( int32_t ) std : : thread : : hardware_concurrency ( ) ) ;
int32_t offset_ms = 0 ;
bool verbose = false ;
bool verbose = false ;
bool translate = false ;
bool translate = false ;
bool output_txt = false ;
bool output_vtt = false ;
bool output_srt = false ;
bool print_special_tokens = false ;
bool print_special_tokens = false ;
bool no_timestamps = false ;
bool no_timestamps = false ;
std : : string language = " en " ;
std : : string language = " en " ;
std : : string model = " models/ggml-base.en.bin " ;
std : : string model = " models/ggml-base.en.bin " ;
std : : string fname_inp = " samples/jfk.wav " ;
std : : vector < std : : string > fname_inp = { } ;
} ;
} ;
void whisper_print_usage ( int argc , char * * argv , const whisper_params & params ) ;
void whisper_print_usage ( int argc , char * * argv , const whisper_params & params ) ;
@ -45,10 +51,17 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for ( int i = 1 ; i < argc ; i + + ) {
for ( int i = 1 ; i < argc ; i + + ) {
std : : string arg = argv [ i ] ;
std : : string arg = argv [ i ] ;
if ( arg [ 0 ] ! = ' - ' ) {
params . fname_inp . push_back ( arg ) ;
continue ;
}
if ( arg = = " -s " | | arg = = " --seed " ) {
if ( arg = = " -s " | | arg = = " --seed " ) {
params . seed = std : : stoi ( argv [ + + i ] ) ;
params . seed = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -t " | | arg = = " --threads " ) {
} else if ( arg = = " -t " | | arg = = " --threads " ) {
params . n_threads = std : : stoi ( argv [ + + i ] ) ;
params . n_threads = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -o " | | arg = = " --offset " ) {
params . offset_ms = std : : stoi ( argv [ + + i ] ) ;
} else if ( arg = = " -v " | | arg = = " --verbose " ) {
} else if ( arg = = " -v " | | arg = = " --verbose " ) {
params . verbose = true ;
params . verbose = true ;
} else if ( arg = = " --translate " ) {
} else if ( arg = = " --translate " ) {
@ -60,6 +73,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage ( argc , argv , params ) ;
whisper_print_usage ( argc , argv , params ) ;
exit ( 0 ) ;
exit ( 0 ) ;
}
}
} else if ( arg = = " -otxt " | | arg = = " --output-txt " ) {
params . output_txt = true ;
} else if ( arg = = " -ovtt " | | arg = = " --output-vtt " ) {
params . output_vtt = true ;
} else if ( arg = = " -osrt " | | arg = = " --output-srt " ) {
params . output_srt = true ;
} else if ( arg = = " -ps " | | arg = = " --print_special " ) {
} else if ( arg = = " -ps " | | arg = = " --print_special " ) {
params . print_special_tokens = true ;
params . print_special_tokens = true ;
} else if ( arg = = " -nt " | | arg = = " --no_timestamps " ) {
} else if ( arg = = " -nt " | | arg = = " --no_timestamps " ) {
@ -67,7 +86,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
} else if ( arg = = " -m " | | arg = = " --model " ) {
} else if ( arg = = " -m " | | arg = = " --model " ) {
params . model = argv [ + + i ] ;
params . model = argv [ + + i ] ;
} else if ( arg = = " -f " | | arg = = " --file " ) {
} else if ( arg = = " -f " | | arg = = " --file " ) {
params . fname_inp = argv [ + + i ] ;
params . fname_inp . push_back ( argv [ + + i ] ) ;
} else if ( arg = = " -h " | | arg = = " --help " ) {
} else if ( arg = = " -h " | | arg = = " --help " ) {
whisper_print_usage ( argc , argv , params ) ;
whisper_print_usage ( argc , argv , params ) ;
exit ( 0 ) ;
exit ( 0 ) ;
@ -83,19 +102,23 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
void whisper_print_usage ( int argc , char * * argv , const whisper_params & params ) {
void whisper_print_usage ( int argc , char * * argv , const whisper_params & params ) {
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " usage: %s [options] \n " , argv [ 0 ] ) ;
fprintf ( stderr , " usage: %s [options] file0.wav file1.wav ... \n " , argv [ 0 ] ) ;
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " options: \n " ) ;
fprintf ( stderr , " options: \n " ) ;
fprintf ( stderr , " -h, --help show this help message and exit \n " ) ;
fprintf ( stderr , " -h, --help show this help message and exit \n " ) ;
fprintf ( stderr , " -s SEED, --seed SEED RNG seed (default: -1) \n " ) ;
fprintf ( stderr , " -s SEED, --seed SEED RNG seed (default: -1) \n " ) ;
fprintf ( stderr , " -t N, --threads N number of threads to use during computation (default: %d) \n " , params . n_threads ) ;
fprintf ( stderr , " -t N, --threads N number of threads to use during computation (default: %d) \n " , params . n_threads ) ;
fprintf ( stderr , " -o N, --offset N offset in milliseconds (default: %d) \n " , params . offset_ms ) ;
fprintf ( stderr , " -v, --verbose verbose output \n " ) ;
fprintf ( stderr , " -v, --verbose verbose output \n " ) ;
fprintf ( stderr , " --translate translate from source language to english \n " ) ;
fprintf ( stderr , " --translate translate from source language to english \n " ) ;
fprintf ( stderr , " -otxt, --output-txt output result in a text file \n " ) ;
fprintf ( stderr , " -ovtt, --output-vtt output result in a vtt file \n " ) ;
fprintf ( stderr , " -osrt, --output-srt output result in a srt file \n " ) ;
fprintf ( stderr , " -ps, --print_special print special tokens \n " ) ;
fprintf ( stderr , " -ps, --print_special print special tokens \n " ) ;
fprintf ( stderr , " -nt, --no_timestamps do not print timestamps \n " ) ;
fprintf ( stderr , " -nt, --no_timestamps do not print timestamps \n " ) ;
fprintf ( stderr , " -l LANG, --language LANG spoken language (default: %s) \n " , params . language . c_str ( ) ) ;
fprintf ( stderr , " -l LANG, --language LANG spoken language (default: %s) \n " , params . language . c_str ( ) ) ;
fprintf ( stderr , " -m FNAME, --model FNAME model path (default: %s) \n " , params . model . c_str ( ) ) ;
fprintf ( stderr , " -m FNAME, --model FNAME model path (default: %s) \n " , params . model . c_str ( ) ) ;
fprintf ( stderr , " -f FNAME, --file FNAME input WAV file path (default: %s) \n " , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " -f FNAME, --file FNAME input WAV file path \n " ) ;
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " \n " ) ;
}
}
@ -110,33 +133,42 @@ int main(int argc, char ** argv) {
params . seed = time ( NULL ) ;
params . seed = time ( NULL ) ;
}
}
if ( params . fname_inp . empty ( ) ) {
fprintf ( stderr , " error: no input files specified \n " ) ;
whisper_print_usage ( argc , argv , params ) ;
return 2 ;
}
// whisper init
// whisper init
struct whisper_context * ctx = whisper_init ( params . model . c_str ( ) ) ;
struct whisper_context * ctx = whisper_init ( params . model . c_str ( ) ) ;
for ( int f = 0 ; f < ( int ) params . fname_inp . size ( ) ; + + f ) {
const auto fname_inp = params . fname_inp [ f ] ;
// WAV input
// WAV input
std : : vector < float > pcmf32 ;
std : : vector < float > pcmf32 ;
{
{
drwav wav ;
drwav wav ;
if ( ! drwav_init_file ( & wav , params . fname_inp . c_str ( ) , NULL ) ) {
if ( ! drwav_init_file ( & wav , fname_inp . c_str ( ) , NULL ) ) {
fprintf ( stderr , " %s: failed to open WAV file '%s' - check your input \n " , argv [ 0 ] , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " %s: failed to open WAV file '%s' - check your input \n " , argv [ 0 ] , fname_inp . c_str ( ) ) ;
whisper_print_usage ( argc , argv , { } ) ;
whisper_print_usage ( argc , argv , { } ) ;
return 2 ;
return 3 ;
}
}
if ( wav . channels ! = 1 & & wav . channels ! = 2 ) {
if ( wav . channels ! = 1 & & wav . channels ! = 2 ) {
fprintf ( stderr , " %s: WAV file '%s' must be mono or stereo \n " , argv [ 0 ] , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " %s: WAV file '%s' must be mono or stereo \n " , argv [ 0 ] , fname_inp . c_str ( ) ) ;
return 3 ;
return 4 ;
}
}
if ( wav . sampleRate ! = WHISPER_SAMPLE_RATE ) {
if ( wav . sampleRate ! = WHISPER_SAMPLE_RATE ) {
fprintf ( stderr , " %s: WAV file '%s' must be 16 kHz \n " , argv [ 0 ] , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " %s: WAV file '%s' must be 16 kHz \n " , argv [ 0 ] , fname_inp . c_str ( ) ) ;
return 4 ;
return 5 ;
}
}
if ( wav . bitsPerSample ! = 16 ) {
if ( wav . bitsPerSample ! = 16 ) {
fprintf ( stderr , " %s: WAV file '%s' must be 16-bit \n " , argv [ 0 ] , params . fname_inp . c_str ( ) ) ;
fprintf ( stderr , " %s: WAV file '%s' must be 16-bit \n " , argv [ 0 ] , fname_inp . c_str ( ) ) ;
return 5 ;
return 6 ;
}
}
int n = wav . totalPCMFrameCount ;
int n = wav . totalPCMFrameCount ;
@ -161,22 +193,24 @@ int main(int argc, char ** argv) {
// print some info about the processing
// print some info about the processing
{
{
printf( " \n " ) ;
f printf( stderr , " \n " ) ;
if ( ! whisper_is_multilingual ( ctx ) ) {
if ( ! whisper_is_multilingual ( ctx ) ) {
if ( params . language ! = " en " | | params . translate ) {
if ( params . language ! = " en " | | params . translate ) {
params . language = " en " ;
params . language = " en " ;
params . translate = false ;
params . translate = false ;
printf( " %s: WARNING: model is not multilingual, ignoring language and translation options \n " , __func__ ) ;
f printf( stderr , " %s: WARNING: model is not multilingual, ignoring language and translation options \n " , __func__ ) ;
}
}
}
}
printf( " %s: processing %d samples ( %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n " ,
f printf( stderr , " %s: processing '%s' ( %d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n " ,
__func__ , int ( pcmf32 . size ( ) ) , float ( pcmf32 . size ( ) ) / WHISPER_SAMPLE_RATE , params . n_threads ,
__func__ , fname_inp . c_str ( ) , int ( pcmf32 . size ( ) ) , float ( pcmf32 . size ( ) ) / WHISPER_SAMPLE_RATE , params . n_threads ,
params . language . c_str ( ) ,
params . language . c_str ( ) ,
params . translate ? " translate " : " transcribe " ,
params . translate ? " translate " : " transcribe " ,
params . no_timestamps ? 0 : 1 ) ;
params . no_timestamps ? 0 : 1 ) ;
printf ( " \n " ) ;
fprintf ( stderr , " \n " ) ;
}
}
// run the inference
// run the inference
{
{
whisper_full_params wparams = whisper_full_default_params ( WHISPER_DECODE_GREEDY ) ;
whisper_full_params wparams = whisper_full_default_params ( WHISPER_DECODE_GREEDY ) ;
@ -188,13 +222,14 @@ int main(int argc, char ** argv) {
wparams . translate = params . translate ;
wparams . translate = params . translate ;
wparams . language = params . language . c_str ( ) ;
wparams . language = params . language . c_str ( ) ;
wparams . n_threads = params . n_threads ;
wparams . n_threads = params . n_threads ;
wparams . offset_ms = params . offset_ms ;
if ( whisper_full ( ctx , wparams , pcmf32 . data ( ) , pcmf32 . size ( ) ) ! = 0 ) {
if ( whisper_full ( ctx , wparams , pcmf32 . data ( ) , pcmf32 . size ( ) ) ! = 0 ) {
fprintf ( stderr , " %s: failed to process audio \n " , argv [ 0 ] ) ;
fprintf ( stderr , " %s: failed to process audio \n " , argv [ 0 ] ) ;
return 6 ;
return 7 ;
}
}
// print result ;
// print result
if ( ! wparams . print_realtime ) {
if ( ! wparams . print_realtime ) {
printf ( " \n " ) ;
printf ( " \n " ) ;
@ -203,13 +238,84 @@ int main(int argc, char ** argv) {
const char * text = whisper_full_get_segment_text ( ctx , i ) ;
const char * text = whisper_full_get_segment_text ( ctx , i ) ;
if ( params . no_timestamps ) {
if ( params . no_timestamps ) {
printf ( " %s " , text ) ;
printf ( " %s " , text ) ;
fflush ( stdout ) ;
fflush ( stdout ) ;
} else {
} else {
const int64_t t0 = whisper_full_get_segment_t0 ( ctx , i ) ;
const int64_t t0 = whisper_full_get_segment_t0 ( ctx , i ) ;
const int64_t t1 = whisper_full_get_segment_t1 ( ctx , i ) ;
const int64_t t1 = whisper_full_get_segment_t1 ( ctx , i ) ;
printf ( " [%s --> %s] %s \n " , to_timestamp ( t0 ) . c_str ( ) , to_timestamp ( t1 ) . c_str ( ) , text ) ;
printf ( " [%s --> %s] %s \n " , to_timestamp ( t0 ) . c_str ( ) , to_timestamp ( t1 ) . c_str ( ) , text ) ;
}
}
}
printf ( " \n " ) ;
// output to text file
if ( params . output_txt ) {
const auto fname_txt = fname_inp + " .txt " ;
std : : ofstream fout_txt ( fname_txt ) ;
if ( ! fout_txt . is_open ( ) ) {
fprintf ( stderr , " %s: failed to open '%s' for writing \n " , __func__ , fname_txt . c_str ( ) ) ;
return 8 ;
}
fprintf ( stderr , " %s: saving output to '%s.txt' \n " , __func__ , fname_inp . c_str ( ) ) ;
const int n_segments = whisper_full_n_segments ( ctx ) ;
for ( int i = 0 ; i < n_segments ; + + i ) {
const char * text = whisper_full_get_segment_text ( ctx , i ) ;
fout_txt < < text ;
}
}
// output to VTT file
if ( params . output_vtt ) {
const auto fname_vtt = fname_inp + " .vtt " ;
std : : ofstream fout_vtt ( fname_vtt ) ;
if ( ! fout_vtt . is_open ( ) ) {
fprintf ( stderr , " %s: failed to open '%s' for writing \n " , __func__ , fname_vtt . c_str ( ) ) ;
return 9 ;
}
fprintf ( stderr , " %s: saving output to '%s.vtt' \n " , __func__ , fname_inp . c_str ( ) ) ;
fout_vtt < < " WEBVTT \n \n " ;
const int n_segments = whisper_full_n_segments ( ctx ) ;
for ( int i = 0 ; i < n_segments ; + + i ) {
const char * text = whisper_full_get_segment_text ( ctx , i ) ;
const int64_t t0 = whisper_full_get_segment_t0 ( ctx , i ) ;
const int64_t t1 = whisper_full_get_segment_t1 ( ctx , i ) ;
fout_vtt < < to_timestamp ( t0 ) < < " --> " < < to_timestamp ( t1 ) < < " \n " ;
fout_vtt < < text < < " \n \n " ;
}
}
// output to SRT file
if ( params . output_srt ) {
const auto fname_srt = fname_inp + " .srt " ;
std : : ofstream fout_srt ( fname_srt ) ;
if ( ! fout_srt . is_open ( ) ) {
fprintf ( stderr , " %s: failed to open '%s' for writing \n " , __func__ , fname_srt . c_str ( ) ) ;
return 10 ;
}
fprintf ( stderr , " %s: saving output to '%s.srt' \n " , __func__ , fname_inp . c_str ( ) ) ;
const int n_segments = whisper_full_n_segments ( ctx ) ;
for ( int i = 0 ; i < n_segments ; + + i ) {
const char * text = whisper_full_get_segment_text ( ctx , i ) ;
const int64_t t0 = whisper_full_get_segment_t0 ( ctx , i ) ;
const int64_t t1 = whisper_full_get_segment_t1 ( ctx , i ) ;
fout_srt < < i + 1 < < " \n " ;
fout_srt < < to_timestamp ( t0 ) < < " --> " < < to_timestamp ( t1 ) < < " \n " ;
fout_srt < < text < < " \n \n " ;
}
}
}
}
}
}