whisper.android: Support benchmark for Android example.

pull/542/head
tinoue 2 years ago committed by tinoue
parent ca21f7ab16
commit 32d1a7b652

@ -104,8 +104,8 @@ int main(int argc, char ** argv) {
switch (params.what) { switch (params.what) {
case 0: ret = whisper_bench_encoder(params); break; case 0: ret = whisper_bench_encoder(params); break;
case 1: ret = whisper_bench_memcpy(params.n_threads); break; case 1: ret = whisper_print_bench_memcpy(params.n_threads); break;
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break; case 2: ret = whisper_print_bench_ggml_mul_mat(params.n_threads); break;
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break; default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
} }

@ -19,6 +19,7 @@ fun MainScreen(viewModel: MainScreenViewModel) {
canTranscribe = viewModel.canTranscribe, canTranscribe = viewModel.canTranscribe,
isRecording = viewModel.isRecording, isRecording = viewModel.isRecording,
messageLog = viewModel.dataLog, messageLog = viewModel.dataLog,
onBenchmarkTapped = viewModel::benchmark,
onTranscribeSampleTapped = viewModel::transcribeSample, onTranscribeSampleTapped = viewModel::transcribeSample,
onRecordTapped = viewModel::toggleRecord onRecordTapped = viewModel::toggleRecord
) )
@ -30,6 +31,7 @@ private fun MainScreen(
canTranscribe: Boolean, canTranscribe: Boolean,
isRecording: Boolean, isRecording: Boolean,
messageLog: String, messageLog: String,
onBenchmarkTapped: () -> Unit,
onTranscribeSampleTapped: () -> Unit, onTranscribeSampleTapped: () -> Unit,
onRecordTapped: () -> Unit onRecordTapped: () -> Unit
) { ) {
@ -45,8 +47,11 @@ private fun MainScreen(
.padding(innerPadding) .padding(innerPadding)
.padding(16.dp) .padding(16.dp)
) { ) {
Row(horizontalArrangement = Arrangement.SpaceBetween) { Column(verticalArrangement = Arrangement.SpaceBetween) {
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped) Row(horizontalArrangement = Arrangement.SpaceBetween) {
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
}
RecordButton( RecordButton(
enabled = canTranscribe, enabled = canTranscribe,
isRecording = isRecording, isRecording = isRecording,
@ -63,6 +68,13 @@ private fun MessageLog(log: String) {
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log) Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
} }
@Composable
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
Button(onClick = onClick, enabled = enabled) {
Text("Benchmark")
}
}
@Composable @Composable
private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) { private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) {
Button(onClick = onClick, enabled = enabled) { Button(onClick = onClick, enabled = enabled) {

@ -41,10 +41,15 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
init { init {
viewModelScope.launch { viewModelScope.launch {
printSystemInfo()
loadData() loadData()
} }
} }
private suspend fun printSystemInfo() {
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
}
private suspend fun loadData() { private suspend fun loadData() {
printMessage("Loading data...\n") printMessage("Loading data...\n")
try { try {
@ -81,10 +86,29 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath) //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
} }
fun benchmark() = viewModelScope.launch {
runBenchmark(6)
}
fun transcribeSample() = viewModelScope.launch { fun transcribeSample() = viewModelScope.launch {
transcribeAudio(getFirstSample()) transcribeAudio(getFirstSample())
} }
private suspend fun runBenchmark(nthreads: Int) {
if (!canTranscribe) {
return
}
canTranscribe = false
printMessage("Running benchmark. This will take minutes...\n")
whisperContext?.benchMemory(nthreads)?.let{ printMessage(it) }
printMessage("\n")
whisperContext?.benchGgmlMulMat(nthreads)?.let{ printMessage(it) }
canTranscribe = true
}
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) { private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
samplesPath.listFiles()!!.first() samplesPath.listFiles()!!.first()
} }
@ -114,11 +138,14 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
canTranscribe = false canTranscribe = false
try { try {
printMessage("Reading wave samples...\n") printMessage("Reading wave samples... ")
val data = readAudioSamples(file) val data = readAudioSamples(file)
printMessage("${data.size / (16000 / 1000)} ms\n")
printMessage("Transcribing data...\n") printMessage("Transcribing data...\n")
val start = System.currentTimeMillis()
val text = whisperContext?.transcribeData(data) val text = whisperContext?.transcribeData(data)
printMessage("Done: $text\n") val elapsed = System.currentTimeMillis() - start
printMessage("Done ($elapsed ms): $text\n")
} catch (e: Exception) { } catch (e: Exception) {
Log.w(LOG_TAG, e) Log.w(LOG_TAG, e)
printMessage("${e.localizedMessage}\n") printMessage("${e.localizedMessage}\n")

@ -27,6 +27,14 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
} }
suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchMemcpy(nthreads)
}
suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchGgmlMulMat(nthreads)
}
suspend fun release() = withContext(scope.coroutineContext) { suspend fun release() = withContext(scope.coroutineContext) {
if (ptr != 0L) { if (ptr != 0L) {
WhisperLib.freeContext(ptr) WhisperLib.freeContext(ptr)
@ -66,6 +74,10 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
return WhisperContext(ptr) return WhisperContext(ptr)
} }
fun getSystemInfo(): String {
return WhisperLib.getSystemInfo()
}
} }
} }
@ -103,6 +115,9 @@ private class WhisperLib {
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray) external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
external fun benchMemcpy(nthread: Int): String
external fun benchGgmlMulMat(nthread: Int): String
} }
} }

@ -6,6 +6,7 @@
#include <sys/sysinfo.h> #include <sys/sysinfo.h>
#include <string.h> #include <string.h>
#include "whisper.h" #include "whisper.h"
#include "ggml.h"
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define TAG "JNI" #define TAG "JNI"
@ -214,3 +215,29 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
jstring string = (*env)->NewStringUTF(env, text); jstring string = (*env)->NewStringUTF(env, text);
return string; return string;
} }
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
JNIEnv *env, jobject thiz
) {
UNUSED(thiz);
const char *sysinfo = whisper_print_system_info();
jstring string = (*env)->NewStringUTF(env, sysinfo);
return string;
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_memcpy = whisper_bench_memcpy(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
}

@ -4381,7 +4381,16 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
// Will be removed in the future when ggml becomes a separate library // Will be removed in the future when ggml becomes a separate library
// //
WHISPER_API int whisper_bench_memcpy(int n_threads) { WHISPER_API int whisper_print_bench_memcpy(int n_threads) {
fputs(whisper_bench_memcpy(n_threads), stderr);
return 0;
}
WHISPER_API const char * whisper_bench_memcpy(int n_threads) {
static std::string s;
s = "";
char strbuf[256];
ggml_time_init(); ggml_time_init();
size_t n = 50; size_t n = 50;
@ -4411,7 +4420,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
src[0] = rand(); src[0] = rand();
} }
fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
s += strbuf;
// needed to prevent the compile from optimizing the memcpy away // needed to prevent the compile from optimizing the memcpy away
{ {
@ -4419,16 +4429,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
for (size_t i = 0; i < size; i++) sum += dst[i]; for (size_t i = 0; i < size; i++) sum += dst[i];
fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
s += strbuf;
} }
free(src); free(src);
free(dst); free(dst);
return s.c_str();
}
WHISPER_API int whisper_print_bench_ggml_mul_mat(int n_threads) {
fputs(whisper_bench_ggml_mul_mat(n_threads), stderr);
return 0; return 0;
} }
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { WHISPER_API const char * whisper_bench_ggml_mul_mat(int n_threads) {
static std::string s;
s = "";
char strbuf[256];
ggml_time_init(); ggml_time_init();
const int n_max = 128; const int n_max = 128;
@ -4504,11 +4524,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
s = ((2.0*N*N*N*n)/tsum)*1e-9; s = ((2.0*N*N*N*n)/tsum)*1e-9;
} }
fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
N, N, s_fp16, n_fp16, s_fp32, n_fp32); N, N, s_fp16, n_fp16, s_fp32, n_fp32);
s += strbuf;
} }
return 0; return s.c_str();
} }
// ================================================================================================= // =================================================================================================

@ -383,8 +383,10 @@ extern "C" {
// Temporary helpers needed for exposing ggml interface // Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_print_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); WHISPER_API const char * whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_print_bench_ggml_mul_mat(int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat(int n_threads);
#ifdef __cplusplus #ifdef __cplusplus
} }

Loading…
Cancel
Save