From 48389940beb92453b50d520d4d05dd3f622357fb Mon Sep 17 00:00:00 2001
From: Syahmi Azhar
Date: Mon, 2 Jan 2023 01:43:48 +0800
Subject: [PATCH] android : load models directly from assets

The model is streamed from the APK assets through a whisper_model_loader
backed by a java.io.InputStream instead of being copied to the working
directory first. The read callback honours the byte count returned by
InputStream.read() and end of stream is detected from read() itself rather
than from InputStream.available(), which is only an estimate.

---
 .../ui/main/MainScreenViewModel.kt            |  14 ++-
 .../com/whispercppdemo/whisper/LibWhisper.kt  |  17 ++-
 .../app/src/main/jni/whisper/jni.c            | 106 ++++++++++++++++++
 3 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt
index 8664440..d741748 100644
--- a/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt
+++ b/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt
@@ -64,16 +64,22 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
     private suspend fun copyAssets() = withContext(Dispatchers.IO) {
         modelsPath.mkdirs()
         samplesPath.mkdirs()
-        application.copyData("models", modelsPath, ::printMessage)
         application.copyData("samples", samplesPath, ::printMessage)
         printMessage("All data copied to working directory.\n")
     }
 
     private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
         printMessage("Loading model...\n")
-        val firstModel = modelsPath.listFiles()!!.first()
-        whisperContext = WhisperContext.createContext(firstModel.absolutePath)
-        printMessage("Loaded model ${firstModel.name}.\n")
+        val modelName = application.assets.list("models/")?.firstOrNull()
+        if (modelName == null) {
+            printMessage("No model found in assets/models.\n")
+        } else {
+            // `use` closes the asset stream once the native loader has consumed it.
+            application.assets.open("models/$modelName").use { stream ->
+                whisperContext = WhisperContext.createContextFromInputStream(stream)
+            }
+            printMessage("Loaded model $modelName.\n")
+        }
     }
 
     fun transcribeSample() = viewModelScope.launch {
diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
index a6dfdcc..edd041a 100644
--- a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
+++ b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
@@ -4,6 +4,7 @@ import android.os.Build
 import android.util.Log
 import kotlinx.coroutines.*
 import java.io.File
+import java.io.InputStream
 import java.util.concurrent.Executors
 
 private const val LOG_TAG = "LibWhisper"
@@ -39,13 +40,26 @@ class WhisperContext private constructor(private var ptr: Long) {
     }
 
     companion object {
-        fun createContext(filePath: String): WhisperContext {
+        fun createContextFromFile(filePath: String): WhisperContext {
             val ptr = WhisperLib.initContext(filePath)
             if (ptr == 0L) {
                 throw java.lang.RuntimeException("Couldn't create context with path $filePath")
             }
             return WhisperContext(ptr)
         }
+
+        /**
+         * Creates a context by streaming the model from [stream]. The stream is read
+         * synchronously during this call and is not closed; the caller owns it.
+         */
+        fun createContextFromInputStream(stream: InputStream): WhisperContext {
+            val ptr = WhisperLib.initContextFromInputStream(stream)
+
+            if (ptr == 0L) {
+                throw java.lang.RuntimeException("Couldn't create context from input stream")
+            }
+            return WhisperContext(ptr)
+        }
     }
 }
 
@@ -76,6 +90,7 @@ private class WhisperLib {
         }
 
         // JNI methods
+        external fun initContextFromInputStream(inputStream: InputStream): Long
         external fun initContext(modelPath: String): Long
         external fun freeContext(contextPtr: Long)
         external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
diff --git a/examples/whisper.android/app/src/main/jni/whisper/jni.c b/examples/whisper.android/app/src/main/jni/whisper/jni.c
index 0992943..0fd2897 100644
--- a/examples/whisper.android/app/src/main/jni/whisper/jni.c
+++ b/examples/whisper.android/app/src/main/jni/whisper/jni.c
@@ -2,6 +2,7 @@
 #include <android/log.h>
 #include <stdlib.h>
 #include <sys/sysinfo.h>
+#include <string.h>
 #include "whisper.h"
 
 #define UNUSED(x) (void)(x)
@@ -17,6 +18,111 @@ static inline int max(int a, int b) {
     return (a > b) ? a : b;
 }
 
+// Size of the reusable Java byte[] used to shuttle data out of the InputStream.
+#define INPUT_STREAM_BUFFER_SIZE (1 << 20)
+
+// State shared by the whisper_model_loader callbacks below. The JNI handles are
+// only valid for the duration of the native call that filled them in.
+struct input_stream_context {
+    size_t offset;        // total bytes handed to whisper so far
+    bool eof;             // set once read() hit end of stream or failed
+    JNIEnv * env;
+    jobject input_stream;
+    jbyteArray buffer;    // local ref to a byte[INPUT_STREAM_BUFFER_SIZE]
+    jmethodID mid_read;
+};
+
+// Reads up to read_size bytes from the Java stream into output and returns the
+// number of bytes actually copied. InputStream.read() may deliver fewer bytes
+// than requested, so it is called repeatedly until the request is satisfied,
+// the stream ends (-1) or a Java exception is raised.
+static size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
+    struct input_stream_context * is = (struct input_stream_context *) ctx;
+    JNIEnv * env = is->env;
+    size_t total = 0;
+
+    while (total < read_size && !is->eof) {
+        const size_t remaining = read_size - total;
+        const jint chunk = remaining < (size_t) INPUT_STREAM_BUFFER_SIZE
+                ? (jint) remaining : INPUT_STREAM_BUFFER_SIZE;
+
+        const jint n_read = (*env)->CallIntMethod(
+                env, is->input_stream, is->mid_read, is->buffer, 0, chunk);
+
+        // Only exception-handling JNI calls are allowed while an exception is
+        // pending, so stop here and let it surface on the Kotlin side.
+        if ((*env)->ExceptionCheck(env) || n_read <= 0) {
+            is->eof = true;
+            break;
+        }
+
+        (*env)->GetByteArrayRegion(env, is->buffer, 0, n_read, (jbyte *) output + total);
+        total += (size_t) n_read;
+    }
+
+    if (total != read_size) {
+        LOGI("Insufficient Read: Req=%zu, Read=%zu, Offset=%zu", read_size, total, is->offset);
+    }
+
+    is->offset += total;
+    return total;
+}
+
+// Reports end of stream the way feof() does: true only after a read came up
+// short. InputStream.available() is an estimate and may be 0 before the end.
+static bool inputStreamEof(void * ctx) {
+    const struct input_stream_context * is = (const struct input_stream_context *) ctx;
+    return is->eof;
+}
+
+// The stream is owned by the Kotlin caller, so there is nothing to release.
+static void inputStreamClose(void * ctx) {
+    UNUSED(ctx);
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
+        JNIEnv *env, jobject thiz, jobject input_stream) {
+    UNUSED(thiz);
+
+    if (input_stream == NULL) {
+        return 0;
+    }
+
+    struct whisper_model_loader loader = {0};
+    struct input_stream_context inp_ctx = {0};
+
+    inp_ctx.env = env;
+    inp_ctx.input_stream = input_stream;
+
+    jclass cls = (*env)->GetObjectClass(env, input_stream);
+    inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
+    if (inp_ctx.mid_read == NULL) {
+        return 0; // NoSuchMethodError is pending
+    }
+
+    inp_ctx.buffer = (*env)->NewByteArray(env, INPUT_STREAM_BUFFER_SIZE);
+    if (inp_ctx.buffer == NULL) {
+        return 0; // OutOfMemoryError is pending
+    }
+
+    loader.context = &inp_ctx;
+    loader.read = inputStreamRead;
+    loader.eof = inputStreamEof;
+    loader.close = inputStreamClose;
+
+    struct whisper_context *context = whisper_init(&loader);
+
+    // A pending Java exception is rethrown as soon as this function returns, so
+    // the caller would never receive the pointer: free it instead of leaking.
+    if (context != NULL && (*env)->ExceptionCheck(env)) {
+        whisper_free(context);
+        context = NULL;
+    }
+
+    return (jlong) context;
+}
+
 JNIEXPORT jlong JNICALL
 Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
         JNIEnv *env, jobject thiz, jstring model_path_str) {