diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 1031bdd..7c9cb03 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -4,3 +4,4 @@ ggml.h whisper.bundle whisper.cpp whisper.h +dr_wav.h diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 0421cb2..851c52d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -3,6 +3,7 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .") # need to use c++ compiler flags diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.cpp similarity index 52% rename from bindings/ruby/ext/ruby_whisper.c rename to bindings/ruby/ext/ruby_whisper.cpp index 33324a0..e7416ba 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -1,20 +1,32 @@ #include #include "ruby_whisper.h" +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif #define BOOL_PARAMS_SETTER(self, prop, value) \ ruby_whisper_params *rwp; \ Data_Get_Struct(self, ruby_whisper_params, rwp); \ if (value == Qfalse || value == Qnil) { \ - rwp->params->prop = false; \ + rwp->params.prop = false; \ } else { \ - rwp->params->prop = true; \ + rwp->params.prop = true; \ } \ return value; \ #define BOOL_PARAMS_GETTER(self, prop) \ ruby_whisper_params *rwp; \ Data_Get_Struct(self, ruby_whisper_params, rwp); \ - if (rwp->params->prop) { \ + if (rwp->params.prop) { \ return Qtrue; \ } else { \ return Qfalse; \ @@ -31,10 +43,6 @@ static void ruby_whisper_free(ruby_whisper *rw) { } } static void ruby_whisper_params_free(ruby_whisper_params *rwp) { - if (rwp->params) { - free(rwp->params); - rwp->params = NULL; - } } void rb_whisper_mark(ruby_whisper *rw) { @@ -64,7 +72,7 @@ static VALUE ruby_whisper_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); - rwp->params = ALLOC(struct whisper_full_params); + rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -80,29 +88,161 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + if (rw->context == nullptr) { + rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); + } return self; } /* - * params.auto_detection = true|false + * transcribe a single file + * can emit to a block results + * + **/ +static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + ruby_whisper_params *rwp; + VALUE wave_file_path, blk, params; + + rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk); + Data_Get_Struct(self, ruby_whisper, rw); + Data_Get_Struct(params, ruby_whisper_params, rwp); + + if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to wave file"); + } + + std::string fname_inp = StringValueCStr(wave_file_path); + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + // WAV input - this is directly from main.cpp example + { + drwav wav; + std::vector wav_data; // used for pipe input from stdin + + if (fname_inp == "-") { + { + uint8_t buf[1024]; + while (true) { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return self; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); + return self; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str()); + return self; + } + + if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) { + fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str()); + return self; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); + return self; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str()); + return self; + } + + const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (rwp->diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + } + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + rwp->params.encoder_begin_callback_user_data = &is_aborted; + } + + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { + fprintf(stderr, "failed to process audio\n"); + return self; + } + const int n_segments = whisper_full_n_segments(rw->context); + VALUE output = rb_str_new2(""); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(rw->context, i); + output = rb_str_concat(output, rb_str_new2(text)); + } + VALUE idCall = rb_intern("call"); + if (blk != Qnil) { + rb_funcall(blk, idCall, 1, output); + } + return self; +} + +/* + * params.language = "auto" | "en", etc... */ -static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { +static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); if (value == Qfalse || value == Qnil) { - rwp->params->language = NULL; + rwp->params.language = "auto"; } else { - rwp->params->language = "auto"; + rwp->params.language = StringValueCStr(value); } return value; } -static VALUE ruby_whisper_params_get_auto_detection(VALUE self) { +static VALUE ruby_whisper_params_get_language(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - if (rwp->params->language) { - return Qtrue; + if (rwp->params.language) { + return rb_str_new2(rwp->params.language); } else { - return Qfalse; + return rb_str_new2("auto"); } } static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { @@ -159,39 +299,76 @@ static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALU static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) } +static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, token_timestamps) +} +static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, token_timestamps, value) +} +static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { + BOOL_PARAMS_GETTER(self, split_on_word) +} +static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, split_on_word, value) +} +static VALUE ruby_whisper_params_get_speed_up(VALUE self) { + BOOL_PARAMS_GETTER(self, speed_up) +} +static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, speed_up, value) +} +static VALUE ruby_whisper_params_get_diarize(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->diarize) { + return Qtrue; + } else { + return Qfalse; + } +} +static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->diarize = false; + } else { + rwp->diarize = true; + } \ + return value; +} static VALUE ruby_whisper_params_get_offset(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->offset_ms); + return INT2NUM(rwp->params.offset_ms); } static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->offset_ms = NUM2INT(value); + rwp->params.offset_ms = NUM2INT(value); return value; } static VALUE ruby_whisper_params_get_duration(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->duration_ms); + return INT2NUM(rwp->params.duration_ms); } static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->duration_ms = NUM2INT(value); + rwp->params.duration_ms = NUM2INT(value); return value; } static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->n_max_text_ctx); + return INT2NUM(rwp->params.n_max_text_ctx); } static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->n_max_text_ctx = NUM2INT(value); + rwp->params.n_max_text_ctx = NUM2INT(value); return value; } @@ -203,10 +380,12 @@ void Init_whisper() { rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_alloc_func(cParams, ruby_whisper_params_allocate); - rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1); - rb_define_method(cParams, "auto_detection", ruby_whisper_params_get_auto_detection, 0); + rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1); + rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0); rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); @@ -225,6 +404,14 @@ void Init_whisper() { rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); + rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); + rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); + rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); + rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0); + rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1); + rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); + rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); @@ -234,3 +421,6 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); } +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 246d133..8c35b7c 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -8,7 +8,8 @@ typedef struct { } ruby_whisper; typedef struct { - struct whisper_full_params *params; + struct whisper_full_params params; + bool diarize; } ruby_whisper_params; #endif diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 617a64e..03d5a23 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -12,11 +12,11 @@ class TestWhisper < Test::Unit::TestCase @params = Whisper::Params.new end - def test_autodetect - @params.auto_detection = true - assert @params.auto_detection - @params.auto_detection = false - assert !@params.auto_detection + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" end def test_offset @@ -103,8 +103,36 @@ class TestWhisper < Test::Unit::TestCase assert !@params.suppress_non_speech_tokens end + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end + + def test_speed_up + @params.speed_up = true + assert @params.speed_up + @params.speed_up = false + assert !@params.speed_up + end + def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) {|text| + puts text.inspect + } end end