add definitions for boolean params

pull/500/head
Todd Fisher 1 year ago
parent ec8e2e14e1
commit 126458b6cb
No known key found for this signature in database
GPG Key ID: 15A1619895B1FB4A

@ -1,6 +1,25 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params->prop = false; \
} else { \
rwp->params->prop = true; \
} \
return value; \
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params->prop) { \
return Qtrue; \
} else { \
return Qfalse; \
}
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
@ -71,10 +90,108 @@ static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params->language = NULL;
} else {
rwp->params->language = "auto";
}
return value;
}
static VALUE ruby_whisper_params_get_auto_detection(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params->language) {
return Qtrue;
} else {
rwp->params->language = NULL;
return Qfalse;
}
}
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, translate, value)
}
static VALUE ruby_whisper_params_get_translate(VALUE self) {
BOOL_PARAMS_GETTER(self, translate)
}
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, no_context, value)
}
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
BOOL_PARAMS_GETTER(self, no_context)
}
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, single_segment, value)
}
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
BOOL_PARAMS_GETTER(self, single_segment)
}
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_special, value)
}
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
BOOL_PARAMS_GETTER(self, print_special)
}
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_progress, value)
}
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
BOOL_PARAMS_GETTER(self, print_progress)
}
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_realtime, value)
}
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
BOOL_PARAMS_GETTER(self, print_realtime)
}
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_timestamps, value)
}
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, print_timestamps)
}
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_blank, value)
}
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_blank)
}
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
}
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
}
static VALUE ruby_whisper_params_get_offset(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params->offset_ms);
}
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params->offset_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_duration(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params->duration_ms);
}
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params->duration_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params->n_max_text_ctx);
}
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params->n_max_text_ctx = NUM2INT(value);
return value;
}
@ -89,4 +206,31 @@ void Init_whisper() {
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1);
rb_define_method(cParams, "auto_detection", ruby_whisper_params_get_auto_detection, 0);
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
}

@ -10,14 +10,101 @@ require 'test/unit'
class TestWhisper < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin'))
end
def test_autodetect
@params.auto_detection = true
assert @params.auto_detection
@params.auto_detection = false
assert !@params.auto_detection
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin'))
end
end

Loading…
Cancel
Save