You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
225 lines
7.2 KiB
225 lines
7.2 KiB
package com.whispercppdemo.ui.main
|
|
|
|
import android.app.Application
|
|
import android.content.Context
|
|
import android.media.MediaPlayer
|
|
import android.util.Log
|
|
import androidx.compose.runtime.getValue
|
|
import androidx.compose.runtime.mutableStateOf
|
|
import androidx.compose.runtime.setValue
|
|
import androidx.core.net.toUri
|
|
import androidx.lifecycle.ViewModel
|
|
import androidx.lifecycle.ViewModelProvider
|
|
import androidx.lifecycle.viewModelScope
|
|
import androidx.lifecycle.viewmodel.initializer
|
|
import androidx.lifecycle.viewmodel.viewModelFactory
|
|
import com.whispercppdemo.media.decodeWaveFile
|
|
import com.whispercppdemo.recorder.Recorder
|
|
import com.whispercppdemo.whisper.WhisperContext
|
|
import kotlinx.coroutines.Dispatchers
|
|
import kotlinx.coroutines.launch
|
|
import kotlinx.coroutines.runBlocking
|
|
import kotlinx.coroutines.withContext
|
|
import java.io.File
|
|
|
|
private const val LOG_TAG = "MainScreenViewModel"
|
|
|
|
/**
 * View model backing the main screen of the whisper.cpp demo.
 *
 * On creation it prints system information, copies the bundled sample audio into the app's
 * files directory and loads the first model found in the `models/` asset directory. It then
 * exposes actions to benchmark the model, transcribe the bundled sample, and record audio
 * for transcription. Progress and error messages are appended to [dataLog].
 */
class MainScreenViewModel(private val application: Application) : ViewModel() {
    /** `true` once a model has been loaded and while no transcription or benchmark is running. */
    var canTranscribe by mutableStateOf(false)
        private set

    /** Accumulated log text; only ever appended to via [printMessage]. */
    var dataLog by mutableStateOf("")
        private set

    /** `true` between a successful `startRecording` and the matching stop (or a recorder error). */
    var isRecording by mutableStateOf(false)
        private set

    private val modelsPath = File(application.filesDir, "models")
    private val samplesPath = File(application.filesDir, "samples")
    private val recorder = Recorder()
    private var whisperContext: WhisperContext? = null
    private var mediaPlayer: MediaPlayer? = null
    private var recordedFile: File? = null

    init {
        viewModelScope.launch {
            printSystemInfo()
            loadData()
        }
    }

    /** Appends the system-info string reported by [WhisperContext] to the log. */
    private suspend fun printSystemInfo() {
        printMessage("System Info: ${WhisperContext.getSystemInfo()}\n")
    }

    /**
     * Copies the bundled assets and loads the model. [canTranscribe] is only enabled when both
     * steps succeed; any failure is logged and shown in [dataLog] instead of being rethrown.
     */
    private suspend fun loadData() {
        printMessage("Loading data...\n")
        try {
            copyAssets()
            loadBaseModel()
            canTranscribe = true
        } catch (e: Exception) {
            Log.w(LOG_TAG, e)
            printMessage("${e.localizedMessage}\n")
        }
    }

    /** Appends [msg] to [dataLog] on the main dispatcher, so it is safe to call from any context. */
    private suspend fun printMessage(msg: String) = withContext(Dispatchers.Main) {
        dataLog += msg
    }

    /** Creates the working directories and copies the `samples` assets into [samplesPath]. */
    private suspend fun copyAssets() = withContext(Dispatchers.IO) {
        modelsPath.mkdirs()
        samplesPath.mkdirs()
        //application.copyData("models", modelsPath, ::printMessage)
        application.copyData("samples", samplesPath, ::printMessage)
        printMessage("All data copied to working directory.\n")
    }

    /**
     * Loads the first model listed under the `models/` asset directory into [whisperContext].
     *
     * @throws IllegalStateException if the asset directory is missing or contains no entries, so
     * that the caller does not enable transcription without a model.
     */
    private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
        printMessage("Loading model...\n")
        val modelName = application.assets.list("models/")?.firstOrNull()
            ?: error("No model found in the \"models\" asset directory.")
        whisperContext = WhisperContext.createContextFromAsset(application.assets, "models/$modelName")
        printMessage("Loaded model $modelName.\n")

        //val firstModel = modelsPath.listFiles()!!.first()
        //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
    }

    /** Runs the memory and matrix-multiplication benchmarks with 6 threads. */
    fun benchmark() = viewModelScope.launch {
        runBenchmark(6)
    }

    /** Plays and transcribes the first sample file found in [samplesPath]. */
    fun transcribeSample() = viewModelScope.launch {
        try {
            transcribeAudio(getFirstSample())
        } catch (e: Exception) {
            // A missing sample must be reported, not crash the app from an uncaught exception.
            Log.w(LOG_TAG, e)
            printMessage("${e.localizedMessage}\n")
        }
    }

    /**
     * Runs both benchmarks with [nthreads] threads and prints their output.
     * Does nothing when [canTranscribe] is `false`; the flag is cleared for the duration of the
     * run and always restored afterwards, even if a benchmark fails.
     */
    private suspend fun runBenchmark(nthreads: Int) {
        if (!canTranscribe) {
            return
        }

        canTranscribe = false
        try {
            printMessage("Running benchmark. This will take minutes...\n")
            whisperContext?.benchMemory(nthreads)?.let { printMessage(it) }
            printMessage("\n")
            whisperContext?.benchGgmlMulMat(nthreads)?.let { printMessage(it) }
        } catch (e: Exception) {
            Log.w(LOG_TAG, e)
            printMessage("${e.localizedMessage}\n")
        } finally {
            canTranscribe = true
        }
    }

    /**
     * Returns the first file in [samplesPath].
     *
     * @throws IllegalStateException if the directory cannot be listed or is empty.
     */
    private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
        samplesPath.listFiles()?.firstOrNull()
            ?: error("No sample audio found in $samplesPath.")
    }

    /** Restarts playback with [file] and returns its decoded samples. */
    private suspend fun readAudioSamples(file: File): FloatArray = withContext(Dispatchers.IO) {
        stopPlayback()
        startPlayback(file)
        decodeWaveFile(file)
    }

    /** Stops and releases the current player, if any, on the main dispatcher. */
    private suspend fun stopPlayback() = withContext(Dispatchers.Main) {
        releaseMediaPlayer()
    }

    /**
     * Stops and releases [mediaPlayer] synchronously on the calling thread.
     * Kept non-suspending so [onCleared] can call it without dispatching to the main looper.
     */
    private fun releaseMediaPlayer() {
        mediaPlayer?.stop()
        mediaPlayer?.release()
        mediaPlayer = null
    }

    /** Creates a player for [file] and starts it on the main dispatcher. */
    private suspend fun startPlayback(file: File) = withContext(Dispatchers.Main) {
        mediaPlayer = MediaPlayer.create(application, file.absolutePath.toUri())
        mediaPlayer?.start()
    }

    /**
     * Plays [file], decodes it and prints the transcription together with the elapsed time.
     * Does nothing when [canTranscribe] is `false`; the flag is cleared while running and always
     * restored afterwards. Failures are logged and shown in [dataLog].
     */
    private suspend fun transcribeAudio(file: File) {
        if (!canTranscribe) {
            return
        }

        canTranscribe = false
        try {
            printMessage("Reading wave samples... ")
            val data = readAudioSamples(file)
            // Samples per millisecond; assumes 16 kHz audio - TODO confirm against decodeWaveFile.
            printMessage("${data.size / (16000 / 1000)} ms\n")
            printMessage("Transcribing data...\n")
            val start = System.currentTimeMillis()
            val text = whisperContext?.transcribeData(data)
            val elapsed = System.currentTimeMillis() - start
            printMessage("Done ($elapsed ms): $text\n")
        } catch (e: Exception) {
            Log.w(LOG_TAG, e)
            printMessage("${e.localizedMessage}\n")
        } finally {
            canTranscribe = true
        }
    }

    /**
     * Starts a recording into a fresh temporary file, or stops the running one and transcribes it.
     * Errors from either path are logged, shown in [dataLog], and reset [isRecording].
     */
    fun toggleRecord() = viewModelScope.launch {
        try {
            if (isRecording) {
                recorder.stopRecording()
                isRecording = false
                recordedFile?.let { transcribeAudio(it) }
            } else {
                stopPlayback()
                val file = getTempFileForRecording()
                recorder.startRecording(file) { e ->
                    // The recorder reports failures through this callback; hop back into the
                    // view model scope to update the UI state.
                    viewModelScope.launch {
                        withContext(Dispatchers.Main) {
                            printMessage("${e.localizedMessage}\n")
                            isRecording = false
                        }
                    }
                }
                isRecording = true
                recordedFile = file
            }
        } catch (e: Exception) {
            Log.w(LOG_TAG, e)
            printMessage("${e.localizedMessage}\n")
            isRecording = false
        }
    }

    /** Creates an empty temporary file named `recording<random>.wav`. */
    private suspend fun getTempFileForRecording() = withContext(Dispatchers.IO) {
        // The suffix must include the dot; createTempFile does not add one.
        File.createTempFile("recording", ".wav")
    }

    /**
     * Releases the whisper context and the media player when the view model is destroyed.
     *
     * The player is released directly rather than through [stopPlayback]: that function
     * dispatches to `Dispatchers.Main`, and waiting for it inside `runBlocking` on the main
     * thread (where the framework is expected to invoke this callback) would block the very
     * thread that has to run the dispatched work.
     */
    override fun onCleared() {
        super.onCleared()
        runBlocking {
            whisperContext?.release()
            whisperContext = null
        }
        releaseMediaPlayer()
    }

    companion object {
        /** Factory that builds the view model from the [Application] found in the creation extras. */
        fun factory() = viewModelFactory {
            initializer {
                val application =
                    this[ViewModelProvider.AndroidViewModelFactory.APPLICATION_KEY] as Application
                MainScreenViewModel(application)
            }
        }
    }
}
|
|
|
|
/**
 * Copies every entry listed under the asset directory [assetDirName] into [destDir],
 * overwriting files of the same name. Runs on the IO dispatcher.
 *
 * @param assetDirName asset directory whose entries are copied.
 * @param destDir destination directory; each entry keeps its asset name.
 * @param printMessage callback that receives a progress line before each copy.
 */
private suspend fun Context.copyData(
    assetDirName: String,
    destDir: File,
    printMessage: suspend (String) -> Unit
) = withContext(Dispatchers.IO) {
    assets.list(assetDirName)?.let { entries ->
        for (fileName in entries) {
            val source = "$assetDirName/$fileName"
            Log.v(LOG_TAG, "Processing $source...")

            val target = File(destDir, fileName)
            Log.v(LOG_TAG, "Copying $source to $target...")
            printMessage("Copying $fileName...\n")

            // Both streams are closed by `use`, even when the copy fails part-way.
            assets.open(source).use { input ->
                target.outputStream().use { output -> input.copyTo(output) }
            }
            Log.v(LOG_TAG, "Copied $source to $target")
        }
    }
}