add support to abort running transcription process

pull/534/head
Todd Fisher 2 years ago
parent 24d7ec3b23
commit e38e1ebe8b
No known key found for this signature in database
GPG Key ID: 15A1619895B1FB4A

@ -1,7 +1,20 @@
require 'erb' require 'erb'
require 'open3'
require 'rake/clean'
require 'rake/testtask' require 'rake/testtask'
require 'rubygems/package' require 'rubygems/package'
CLEAN.include '**/*.o'
CLEAN.include "**/*.#{(defined?(RbConfig) ? RbConfig : Config)::MAKEFILE_CONFIG['DLEXT']}"
CLOBBER.include 'doc'
CLOBBER.include '**/*.log'
CLOBBER.include '**/Makefile'
CLOBBER.include '**/extconf.h'
CLOBBER.include '**/extconf.h'
CLOBBER.include '**/whisper.*'
CLOBBER.include '**/ggml.*'
CLOBBER.include '**/dr_wav.h'
BUILD_VERSION=1 BUILD_VERSION=1
# Determine the current version of the software # Determine the current version of the software
if File.read('../../CMakeLists.txt') =~ /project.*\s*VERSION\s*(\d.+)\)/ if File.read('../../CMakeLists.txt') =~ /project.*\s*VERSION\s*(\d.+)\)/
@ -10,6 +23,58 @@ else
CURRENT_VERSION = "0.0.0.#{BUILD_VERSION}" CURRENT_VERSION = "0.0.0.#{BUILD_VERSION}"
end end
def shell(args, opts = {})
puts "> #{args.join(' ')}"
cmd, live_stream, cwd = args, opts[:live_stdout], opts[:cwd]
Dir.chdir(cwd) {
wait_thr = nil
Open3.popen3(*cmd) do |stdin, stdout, stderr, thr|
stdin.close
wait_thr = thr # Ruby 1.8 will not yield thr, this will be nil
while line = stdout.gets do
live_stream.puts(line) if live_stream
end
while line = stderr.gets do
puts line
end
end
# prefer process handle directly from popen3, but if not available
# fallback to global.
p_status = wait_thr ? wait_thr.value : $?
exit_code = p_status.exitstatus
error = (exit_code != 0)
}
end
make_program = (/mswin/ =~ RUBY_PLATFORM) ? 'nmake' : 'make'
MAKECMD = ENV['MAKE_CMD'] || make_program
MAKEOPTS = ENV['MAKE_OPTS'] || ''
WHISPER_SO = "ext/whisper.#{(defined?(RbConfig) ? RbConfig : Config)::MAKEFILE_CONFIG['DLEXT']}"
file 'ext/Makefile' => 'ext/extconf.rb' do
shell(['ruby', 'extconf.rb', ENV['EXTCONF_OPTS'].to_s],
{ live_stdout: STDOUT, cwd: "#{Dir.pwd}/ext" }
)
end
def make(target = '')
shell(["#{MAKECMD}", "#{MAKEOPTS}", "#{target}"].reject(&:empty?),
{ live_stdout: STDOUT, cwd: "#{Dir.pwd}/ext" }
)
end
# Let make handle dependencies between c/o/so - we'll just run it.
file WHISPER_SO => (['ext/Makefile'] + Dir['ext/*.cpp'] + Dir['ext/*.c'] + Dir['ext/*.h']) do
make
end
desc "Compile the shared object"
task :compile => [WHISPER_SO]
desc "Default Task (Test project)" desc "Default Task (Test project)"
task :default => :test task :default => :test

@ -4,6 +4,7 @@ We expose Whisper::Context and Whisper::Params. The Context object can be used
Parameters can be set on the Params object to customize how the transcription is generated. Parameters can be set on the Params object to customize how the transcription is generated.
``` ```
require 'whisper'
whisper = Whisper::Context.new('ggml-base.en.bin') whisper = Whisper::Context.new('ggml-base.en.bin')
params = Whisper::Params.new params = Whisper::Params.new
whisper.transcribe('jfk.wav', params) {|text| whisper.transcribe('jfk.wav', params) {|text|

@ -1,4 +1,5 @@
#include <ruby.h> #include <ruby.h>
#include <ruby/thread.h>
#include "ruby_whisper.h" #include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION #define DR_WAV_IMPLEMENTATION
#include "dr_wav.h" #include "dr_wav.h"
@ -94,6 +95,32 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
return self; return self;
} }
struct WhisperFullParallelParams {
ruby_whisper *rw;
ruby_whisper_params *rwp;
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
};
static void stop_whisper_unblock(void *args) {
struct WhisperFullParallelParams *object = (struct WhisperFullParallelParams *)args;
fprintf(stderr, "Set running to abort\n");
whisper_running_abort(object->rw->context);
}
static VALUE call_whisper_full_parallel(void *args) {
struct WhisperFullParallelParams *object = (struct WhisperFullParallelParams *)args;
whisper_running_restore(object->rw->context);
if (whisper_full_parallel(object->rw->context, object->rwp->params, object->pcmf32.data(), object->pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return INT2FIX(-1);
}
return INT2FIX(0);
}
/* /*
* transcribe a single file * transcribe a single file
* can emit to a block results * can emit to a block results
@ -114,8 +141,9 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
std::string fname_inp = StringValueCStr(wave_file_path); std::string fname_inp = StringValueCStr(wave_file_path);
std::vector<float> pcmf32; // mono-channel F32 PCM //std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM //std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
struct WhisperFullParallelParams object;
// WAV input - this is directly from main.cpp example // WAV input - this is directly from main.cpp example
{ {
@ -173,26 +201,26 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
drwav_uninit(&wav); drwav_uninit(&wav);
// convert to mono, float // convert to mono, float
pcmf32.resize(n); object.pcmf32.resize(n);
if (wav.channels == 1) { if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) { for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f; object.pcmf32[i] = float(pcm16[i])/32768.0f;
} }
} else { } else {
for (uint64_t i = 0; i < n; i++) { for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; object.pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
} }
} }
if (rwp->diarize) { if (rwp->diarize) {
// convert to stereo, float // convert to stereo, float
pcmf32s.resize(2); object.pcmf32s.resize(2);
pcmf32s[0].resize(n); object.pcmf32s[0].resize(n);
pcmf32s[1].resize(n); object.pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) { for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; object.pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; object.pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
} }
} }
} }
@ -206,10 +234,16 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted; rwp->params.encoder_begin_callback_user_data = &is_aborted;
} }
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { object.rw = rw;
object.rwp = rwp;
int r = (int)(VALUE)rb_thread_call_without_gvl((void *(*)(void *))call_whisper_full_parallel, &object, stop_whisper_unblock, &object);
//if (whisper_full_parallel(rw->context, rwp->params, object.pcmf32.data(), pcmf32.size(), 1) != 0) {
if (r != 0) {
fprintf(stderr, "failed to process audio\n"); fprintf(stderr, "failed to process audio\n");
return self; return self;
} }
const int n_segments = whisper_full_n_segments(rw->context); const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2(""); VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {

@ -603,6 +603,9 @@ struct whisper_context {
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default int32_t exp_n_audio_ctx = 0; // 0 - use default
// [EXPERIMENTAL] abort handling
bool running = true;
void use_buf(struct ggml_context * ctx, int i) { void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH) #if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0; size_t last_size = 0;
@ -3654,7 +3657,7 @@ int whisper_full(
std::vector<beam_candidate> beam_candidates; std::vector<beam_candidate> beam_candidates;
// main loop // main loop
while (true) { while (ctx->running) {
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
while (progress_cur >= progress_prev + progress_step) { while (progress_cur >= progress_prev + progress_step) {
progress_prev += progress_step; progress_prev += progress_step;
@ -4204,12 +4207,27 @@ int whisper_full(
return 0; return 0;
} }
void whisper_running_abort(struct whisper_context * ctx) {
ctx->running = false;
}
void whisper_running_restore(struct whisper_context * ctx) {
ctx->running = true;
}
bool whisper_running_state(struct whisper_context * ctx) {
return ctx->running;
}
int whisper_full_parallel( int whisper_full_parallel(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_full_params params, struct whisper_full_params params,
const float * samples, const float * samples,
int n_samples, int n_samples,
int n_processors) { int n_processors) {
ctx->running = true;
if (n_processors == 1) { if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples); return whisper_full(ctx, params, samples, n_samples);
} }

@ -225,6 +225,15 @@ extern "C" {
// Print system information // Print system information
WHISPER_API const char * whisper_print_system_info(void); WHISPER_API const char * whisper_print_system_info(void);
// Abort a running whisper_full_parallel or whisper_full
WHISPER_API void whisper_running_abort(struct whisper_context * ctx);
// Resume whisper context from an aborted state allowing it run again
WHISPER_API void whisper_running_restore(struct whisper_context * ctx);
// Check the whisper context state if true then it can run if false it can not
WHISPER_API bool whisper_running_state(struct whisper_context * ctx);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Available sampling strategies // Available sampling strategies

Loading…
Cancel
Save