whisper.android : support benchmark for Android example. (#542)

* whisper.android: Support benchmark for Android example.

* whisper.android: update screenshot in README.

* update: Make text selectable for copy & paste.

* Update whisper.h to restore API name

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* whisper.android: Restore original API names.

---------

Co-authored-by: tinoue <tinoue@xevo.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
master
Takeshi Inoue 1 year ago committed by GitHub
parent fa9d43181f
commit 09e9068007
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,4 +9,4 @@ To use:
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device. 5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device. [^1]: I recommend the tiny or base models for running on an Android device.
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png"> <img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg">

@ -2,6 +2,7 @@ package com.whispercppdemo.ui.main
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.text.selection.SelectionContainer
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.* import androidx.compose.material3.*
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
@ -19,6 +20,7 @@ fun MainScreen(viewModel: MainScreenViewModel) {
canTranscribe = viewModel.canTranscribe, canTranscribe = viewModel.canTranscribe,
isRecording = viewModel.isRecording, isRecording = viewModel.isRecording,
messageLog = viewModel.dataLog, messageLog = viewModel.dataLog,
onBenchmarkTapped = viewModel::benchmark,
onTranscribeSampleTapped = viewModel::transcribeSample, onTranscribeSampleTapped = viewModel::transcribeSample,
onRecordTapped = viewModel::toggleRecord onRecordTapped = viewModel::toggleRecord
) )
@ -30,6 +32,7 @@ private fun MainScreen(
canTranscribe: Boolean, canTranscribe: Boolean,
isRecording: Boolean, isRecording: Boolean,
messageLog: String, messageLog: String,
onBenchmarkTapped: () -> Unit,
onTranscribeSampleTapped: () -> Unit, onTranscribeSampleTapped: () -> Unit,
onRecordTapped: () -> Unit onRecordTapped: () -> Unit
) { ) {
@ -45,8 +48,11 @@ private fun MainScreen(
.padding(innerPadding) .padding(innerPadding)
.padding(16.dp) .padding(16.dp)
) { ) {
Row(horizontalArrangement = Arrangement.SpaceBetween) { Column(verticalArrangement = Arrangement.SpaceBetween) {
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped) Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) {
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
}
RecordButton( RecordButton(
enabled = canTranscribe, enabled = canTranscribe,
isRecording = isRecording, isRecording = isRecording,
@ -60,7 +66,16 @@ private fun MainScreen(
@Composable @Composable
private fun MessageLog(log: String) { private fun MessageLog(log: String) {
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log) SelectionContainer() {
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
}
}
@Composable
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
Button(onClick = onClick, enabled = enabled) {
Text("Benchmark")
}
} }
@Composable @Composable

@ -41,10 +41,15 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
init { init {
viewModelScope.launch { viewModelScope.launch {
printSystemInfo()
loadData() loadData()
} }
} }
private suspend fun printSystemInfo() {
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
}
private suspend fun loadData() { private suspend fun loadData() {
printMessage("Loading data...\n") printMessage("Loading data...\n")
try { try {
@ -81,10 +86,29 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath) //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
} }
fun benchmark() = viewModelScope.launch {
runBenchmark(6)
}
fun transcribeSample() = viewModelScope.launch { fun transcribeSample() = viewModelScope.launch {
transcribeAudio(getFirstSample()) transcribeAudio(getFirstSample())
} }
private suspend fun runBenchmark(nthreads: Int) {
if (!canTranscribe) {
return
}
canTranscribe = false
printMessage("Running benchmark. This will take minutes...\n")
whisperContext?.benchMemory(nthreads)?.let{ printMessage(it) }
printMessage("\n")
whisperContext?.benchGgmlMulMat(nthreads)?.let{ printMessage(it) }
canTranscribe = true
}
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) { private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
samplesPath.listFiles()!!.first() samplesPath.listFiles()!!.first()
} }
@ -114,11 +138,14 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
canTranscribe = false canTranscribe = false
try { try {
printMessage("Reading wave samples...\n") printMessage("Reading wave samples... ")
val data = readAudioSamples(file) val data = readAudioSamples(file)
printMessage("${data.size / (16000 / 1000)} ms\n")
printMessage("Transcribing data...\n") printMessage("Transcribing data...\n")
val start = System.currentTimeMillis()
val text = whisperContext?.transcribeData(data) val text = whisperContext?.transcribeData(data)
printMessage("Done: $text\n") val elapsed = System.currentTimeMillis() - start
printMessage("Done ($elapsed ms): $text\n")
} catch (e: Exception) { } catch (e: Exception) {
Log.w(LOG_TAG, e) Log.w(LOG_TAG, e)
printMessage("${e.localizedMessage}\n") printMessage("${e.localizedMessage}\n")

@ -27,6 +27,14 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
} }
suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchMemcpy(nthreads)
}
suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchGgmlMulMat(nthreads)
}
suspend fun release() = withContext(scope.coroutineContext) { suspend fun release() = withContext(scope.coroutineContext) {
if (ptr != 0L) { if (ptr != 0L) {
WhisperLib.freeContext(ptr) WhisperLib.freeContext(ptr)
@ -66,6 +74,10 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
return WhisperContext(ptr) return WhisperContext(ptr)
} }
fun getSystemInfo(): String {
return WhisperLib.getSystemInfo()
}
} }
} }
@ -117,6 +129,9 @@ private class WhisperLib {
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray) external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
external fun benchMemcpy(nthread: Int): String
external fun benchGgmlMulMat(nthread: Int): String
} }
} }

@ -6,6 +6,7 @@
#include <sys/sysinfo.h> #include <sys/sysinfo.h>
#include <string.h> #include <string.h>
#include "whisper.h" #include "whisper.h"
#include "ggml.h"
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define TAG "JNI" #define TAG "JNI"
@ -213,4 +214,30 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
const char *text = whisper_full_get_segment_text(context, index); const char *text = whisper_full_get_segment_text(context, index);
jstring string = (*env)->NewStringUTF(env, text); jstring string = (*env)->NewStringUTF(env, text);
return string; return string;
} }
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
JNIEnv *env, jobject thiz
) {
UNUSED(thiz);
const char *sysinfo = whisper_print_system_info();
jstring string = (*env)->NewStringUTF(env, sysinfo);
return string;
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
}

@ -4551,6 +4551,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
// //
WHISPER_API int whisper_bench_memcpy(int n_threads) { WHISPER_API int whisper_bench_memcpy(int n_threads) {
fputs(whisper_bench_memcpy_str(n_threads), stderr);
return 0;
}
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
static std::string s;
s = "";
char strbuf[256];
ggml_time_init(); ggml_time_init();
size_t n = 50; size_t n = 50;
@ -4580,7 +4589,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
src[0] = rand(); src[0] = rand();
} }
fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
s += strbuf;
// needed to prevent the compile from optimizing the memcpy away // needed to prevent the compile from optimizing the memcpy away
{ {
@ -4588,16 +4598,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
for (size_t i = 0; i < size; i++) sum += dst[i]; for (size_t i = 0; i < size; i++) sum += dst[i];
fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
s += strbuf;
} }
free(src); free(src);
free(dst); free(dst);
return 0; return s.c_str();
} }
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr);
return 0;
}
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
static std::string s;
s = "";
char strbuf[256];
ggml_time_init(); ggml_time_init();
const int n_max = 128; const int n_max = 128;
@ -4673,11 +4693,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
s = ((2.0*N*N*N*n)/tsum)*1e-9; s = ((2.0*N*N*N*n)/tsum)*1e-9;
} }
fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
N, N, s_fp16, n_fp16, s_fp32, n_fp32); N, N, s_fp16, n_fp16, s_fp32, n_fp32);
s += strbuf;
} }
return 0; return s.c_str();
} }
// ================================================================================================= // =================================================================================================

@ -462,7 +462,9 @@ extern "C" {
// Temporary helpers needed for exposing ggml interface // Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
#ifdef __cplusplus #ifdef __cplusplus
} }

Loading…
Cancel
Save