From 126458b6cb565cf35d1e6e3910232f70df633579 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 09:43:33 -0500 Subject: [PATCH] add definitions for boolean params --- bindings/ruby/ext/ruby_whisper.c | 146 +++++++++++++++++++++++++++- bindings/ruby/tests/test_whisper.rb | 89 ++++++++++++++++- 2 files changed, 233 insertions(+), 2 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 4abfab5..33324a0 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,6 +1,25 @@ #include #include "ruby_whisper.h" +#define BOOL_PARAMS_SETTER(self, prop, value) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (value == Qfalse || value == Qnil) { \ + rwp->params->prop = false; \ + } else { \ + rwp->params->prop = true; \ + } \ + return value; \ + +#define BOOL_PARAMS_GETTER(self, prop) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (rwp->params->prop) { \ + return Qtrue; \ + } else { \ + return Qfalse; \ + } + VALUE mWhisper; VALUE cContext; VALUE cParams; @@ -71,10 +90,108 @@ static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); if (value == Qfalse || value == Qnil) { + rwp->params->language = NULL; + } else { rwp->params->language = "auto"; + } + return value; +} +static VALUE ruby_whisper_params_get_auto_detection(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->params->language) { + return Qtrue; } else { - rwp->params->language = NULL; + return Qfalse; } +} +static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, translate, value) +} +static VALUE ruby_whisper_params_get_translate(VALUE self) { + BOOL_PARAMS_GETTER(self, translate) +} +static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, no_context, value) +} +static VALUE ruby_whisper_params_get_no_context(VALUE self) { + BOOL_PARAMS_GETTER(self, no_context) +} +static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, single_segment, value) +} +static VALUE ruby_whisper_params_get_single_segment(VALUE self) { + BOOL_PARAMS_GETTER(self, single_segment) +} +static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_special, value) +} +static VALUE ruby_whisper_params_get_print_special(VALUE self) { + BOOL_PARAMS_GETTER(self, print_special) +} +static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_progress, value) +} +static VALUE ruby_whisper_params_get_print_progress(VALUE self) { + BOOL_PARAMS_GETTER(self, print_progress) +} +static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_realtime, value) +} +static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { + BOOL_PARAMS_GETTER(self, print_realtime) +} +static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_timestamps, value) +} +static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, print_timestamps) +} +static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_blank, value) +} +static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_blank) +} +static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) +} +static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) +} + +static VALUE ruby_whisper_params_get_offset(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->offset_ms); +} +static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->offset_ms = NUM2INT(value); + return value; +} +static VALUE ruby_whisper_params_get_duration(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->duration_ms); +} +static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->duration_ms = NUM2INT(value); + return value; +} + +static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->n_max_text_ctx); +} +static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->n_max_text_ctx = NUM2INT(value); return value; } @@ -89,4 +206,31 @@ void Init_whisper() { rb_define_alloc_func(cParams, ruby_whisper_params_allocate); rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1); + rb_define_method(cParams, "auto_detection", ruby_whisper_params_get_auto_detection, 0); + rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); + rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); + rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); + rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0); + rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1); + rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0); + rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0); + rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1); + rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0); + rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1); + rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0); + rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1); + rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0); + rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); + rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); + rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); + rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); + rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + + rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); + rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); + rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0); + rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1); + + rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); + rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); } diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 82c21fe..617a64e 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -10,14 +10,101 @@ require 'test/unit' class TestWhisper < Test::Unit::TestCase def setup @params = Whisper::Params.new - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) end def test_autodetect @params.auto_detection = true + assert @params.auto_detection + @params.auto_detection = false + assert !@params.auto_detection + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true + assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens end def test_whisper + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) end end