You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
4.8 KiB
163 lines
4.8 KiB
import Foundation
|
|
import SwiftUI
|
|
import AVFoundation
|
|
|
|
/// Observable state for the Whisper demo UI.
///
/// Loads the bundled `models/ggml-tiny.en.bin` model at init, records audio through
/// `Recorder`, and transcribes either the latest recording or the bundled
/// `samples/jfk.wav`. Progress and errors are appended to `messageLog`.
@MainActor
class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
    /// `true` once `whisperContext` has been created from the bundled model.
    @Published var isModelLoaded = false
    /// Human-readable progress/error log, one entry per line.
    @Published var messageLog = ""
    /// `true` when a model is loaded and no transcription is currently running.
    @Published var canTranscribe = false
    /// `true` while a recording started by `toggleRecord()` is in progress.
    @Published var isRecording = false

    private var whisperContext: WhisperContext?
    private let recorder = Recorder()
    /// Destination of the most recent recording that started successfully.
    private var recordedFile: URL?
    /// Kept as a property so `stopPlayback()` can stop it later.
    private var audioPlayer: AVAudioPlayer?

    private var modelUrl: URL? {
        Bundle.main.url(forResource: "ggml-tiny.en", withExtension: "bin", subdirectory: "models")
    }

    private var sampleUrl: URL? {
        Bundle.main.url(forResource: "jfk", withExtension: "wav", subdirectory: "samples")
    }

    private enum LoadError: LocalizedError {
        case couldNotLocateModel

        // Supplies the text that ends up in `messageLog` via `localizedDescription`.
        var errorDescription: String? {
            switch self {
            case .couldNotLocateModel:
                return "Could not locate model"
            }
        }
    }

    override init() {
        super.init()
        do {
            try loadModel()
            canTranscribe = true
        } catch {
            // Leave `canTranscribe` false: without a model nothing can be transcribed.
            logError(error)
        }
    }

    /// Creates `whisperContext` from the bundled model and marks the model as loaded.
    /// - Throws: `LoadError.couldNotLocateModel` when the model file is not in the
    ///   bundle, or any error thrown by `WhisperContext.createContext(path:)`.
    private func loadModel() throws {
        messageLog += "Loading model...\n"
        guard let modelUrl else {
            throw LoadError.couldNotLocateModel
        }
        whisperContext = try WhisperContext.createContext(path: modelUrl.path())
        isModelLoaded = true
        messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
    }

    /// Transcribes the bundled `samples/jfk.wav`, logging if it is missing.
    func transcribeSample() async {
        guard let sampleUrl else {
            messageLog += "Could not locate sample\n"
            return
        }
        await transcribeAudio(sampleUrl)
    }

    /// Plays and transcribes the audio file at `url`, logging progress and the result.
    ///
    /// Returns immediately if a transcription is already running (`canTranscribe`
    /// is false) or no model context exists.
    private func transcribeAudio(_ url: URL) async {
        guard canTranscribe, let whisperContext else {
            return
        }

        canTranscribe = false
        // Re-enable transcription on every exit path, including thrown errors.
        defer { canTranscribe = true }

        do {
            messageLog += "Reading wave samples...\n"
            let data = try readAudioSamples(url)
            messageLog += "Transcribing data...\n"
            await whisperContext.fullTranscribe(samples: data)
            let text = await whisperContext.getTranscription()
            messageLog += "Done: \(text)\n"
        } catch {
            logError(error)
        }
    }

    /// Restarts playback of `url` and returns its samples as decoded by
    /// `decodeWaveFile(_:)` (defined elsewhere in the project).
    private func readAudioSamples(_ url: URL) throws -> [Float] {
        stopPlayback()
        try startPlayback(url)
        return try decodeWaveFile(url)
    }

    /// Stops and transcribes the current recording, or starts a new one.
    ///
    /// Starting requires record permission; if it is denied, nothing happens.
    /// New recordings are written to `output.wav` in the user's documents directory.
    func toggleRecord() async {
        if isRecording {
            await recorder.stopRecording()
            isRecording = false
            if let recordedFile {
                await transcribeAudio(recordedFile)
            }
            return
        }

        guard await requestRecordPermission() else {
            return
        }

        do {
            stopPlayback()
            let file = try FileManager.default.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
                .appending(path: "output.wav")
            try await recorder.startRecording(toOutputFile: file, delegate: self)
            isRecording = true
            recordedFile = file
        } catch {
            logError(error)
            isRecording = false
        }
    }

    /// Asks for microphone permission and returns whether it was granted.
    /// On macOS this returns `true` unconditionally.
    private func requestRecordPermission() async -> Bool {
        #if os(macOS)
        return true
        #else
        return await withCheckedContinuation { continuation in
            // `@Sendable` keeps this callback from being inferred as main-actor
            // isolated; Apple documents that it may be invoked on another thread.
            AVAudioSession.sharedInstance().requestRecordPermission { @Sendable granted in
                continuation.resume(returning: granted)
            }
        }
        #endif
    }

    private func startPlayback(_ url: URL) throws {
        audioPlayer = try AVAudioPlayer(contentsOf: url)
        audioPlayer?.play()
    }

    private func stopPlayback() {
        audioPlayer?.stop()
        audioPlayer = nil
    }

    /// Prints `error` to the console and appends its description to `messageLog`.
    private func logError(_ error: Error) {
        print(error.localizedDescription)
        messageLog += "\(error.localizedDescription)\n"
    }

    // MARK: AVAudioRecorderDelegate

    /// Delegate callback (nonisolated); hops to the main actor to record the error.
    nonisolated func audioRecorderEncodeErrorDidOccur(_ recorder: AVAudioRecorder, error: Error?) {
        if let error {
            Task {
                await handleRecError(error)
            }
        }
    }

    private func handleRecError(_ error: Error) {
        logError(error)
        isRecording = false
    }

    /// Delegate callback (nonisolated); hops to the main actor to clear `isRecording`.
    /// The `flag` argument is not inspected.
    nonisolated func audioRecorderDidFinishRecording(_ recorder: AVAudioRecorder, successfully flag: Bool) {
        Task {
            await onDidFinishRecording()
        }
    }

    private func onDidFinishRecording() {
        isRecording = false
    }
}
|