bindings : initial import of golang bindings (#287)
* Initial import of golang bindings * Updated makefile rules * Updated bindings * Makefile update to add in more testspull/298/head
parent
90564f85f9
commit
231bebca7d
@ -0,0 +1,3 @@
|
|||||||
|
build
|
||||||
|
models
|
||||||
|
go.sum
|
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2022 David Thorpe
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
@ -0,0 +1,38 @@
|
|||||||
|
CMAKE := $(shell which cmake)
|
||||||
|
BUILD_DIR := "build"
|
||||||
|
MODELS_DIR := "models"
|
||||||
|
EXAMPLES_DIR := $(wildcard examples/*)
|
||||||
|
C_INCLUDE_PATH := "../.."
|
||||||
|
|
||||||
|
all: clean whisper examples
|
||||||
|
|
||||||
|
whisper: mkdir
|
||||||
|
@echo Build whisper
|
||||||
|
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
||||||
|
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
||||||
|
|
||||||
|
test: model-small whisper
|
||||||
|
@go mod tidy
|
||||||
|
@go test -v .
|
||||||
|
@go test -v ./pkg/whisper/...
|
||||||
|
|
||||||
|
examples: $(EXAMPLES_DIR)
|
||||||
|
|
||||||
|
model-small: mkdir examples/go-model-download
|
||||||
|
@${BUILD_DIR}/go-model-download -out models small.en
|
||||||
|
|
||||||
|
$(EXAMPLES_DIR): mkdir whisper
|
||||||
|
@echo Build example $(notdir $@)
|
||||||
|
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||||
|
|
||||||
|
mkdir:
|
||||||
|
@echo Mkdir ${BUILD_DIR}
|
||||||
|
@install -d ${BUILD_DIR}
|
||||||
|
@echo Mkdir ${MODELS_DIR}
|
||||||
|
@install -d ${MODELS_DIR}
|
||||||
|
|
||||||
|
clean:
|
||||||
|
@echo Clean
|
||||||
|
@rm -fr $(BUILD_DIR)
|
||||||
|
@go mod tidy
|
||||||
|
@go clean
|
@ -0,0 +1,77 @@
|
|||||||
|
# Go bindings for Whisper
|
||||||
|
|
||||||
|
This package provides Go bindings for whisper.cpp. They have been tested on:
|
||||||
|
|
||||||
|
* Darwin (OS X) 12.6 on x64_64
|
||||||
|
* Debian Linux on arm64
|
||||||
|
* Fedora Linux on x86_64
|
||||||
|
|
||||||
|
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||||
|
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||||
|
is as follows:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var modelpath string // Path to the model
|
||||||
|
var samples []float32 // Samples to process
|
||||||
|
|
||||||
|
// Load the model
|
||||||
|
model, err := whisper.New(modelpath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer model.Close()
|
||||||
|
|
||||||
|
// Process samples
|
||||||
|
context, err := model.NewContext()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := context.Process(samples, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out the results
|
||||||
|
for {
|
||||||
|
segment, err := context.NextSegment()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building & Testing
|
||||||
|
|
||||||
|
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||||
|
cd whisper.cpp/bindings/go
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make examples
|
||||||
|
```
|
||||||
|
|
||||||
|
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./build/go-model-download -out models
|
||||||
|
```
|
||||||
|
|
||||||
|
And you can then test a model against samples with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
|||||||
|
/*
|
||||||
|
github.com/ggerganov/whisper.cpp/bindings/go
|
||||||
|
provides a speech-to-text service bindings for the Go programming language.
|
||||||
|
*/
|
||||||
|
package whisper
|
@ -0,0 +1,30 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextForSignal returns a context object which is cancelled when a signal
|
||||||
|
// is received. It returns nil if no signal parameter is provided
|
||||||
|
func ContextForSignal(signals ...os.Signal) context.Context {
|
||||||
|
if len(signals) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan os.Signal)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Send message on channel when signal received
|
||||||
|
signal.Notify(ch, signals...)
|
||||||
|
|
||||||
|
// When any signal received, call cancel
|
||||||
|
go func() {
|
||||||
|
<-ch
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return ctx
|
||||||
|
}
|
@ -0,0 +1,206 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CONSTANTS
|
||||||
|
|
||||||
|
const (
|
||||||
|
srcUrl = "https://huggingface.co/" // The location of the models
|
||||||
|
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
||||||
|
srcExt = ".bin" // Filename extension
|
||||||
|
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// The models which will be downloaded, if no model is specified as an argument
|
||||||
|
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// The output folder. When not set, use current working directory.
|
||||||
|
flagOut = flag.String("out", "", "Output folder")
|
||||||
|
|
||||||
|
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
||||||
|
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
||||||
|
|
||||||
|
// Quiet parameter - will not print progress if set
|
||||||
|
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MAIN
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Usage = func() {
|
||||||
|
name := filepath.Base(flag.CommandLine.Name())
|
||||||
|
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
||||||
|
flag.PrintDefaults()
|
||||||
|
}
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Get output path
|
||||||
|
out, err := GetOut()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create context which quits on SIGINT or SIGQUIT
|
||||||
|
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
||||||
|
|
||||||
|
// Progress filehandle
|
||||||
|
progress := os.Stdout
|
||||||
|
if *flagQuiet {
|
||||||
|
progress, err = os.Open(os.DevNull)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
defer progress.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download models - exit on error or interrupt
|
||||||
|
for _, model := range GetModels() {
|
||||||
|
url, err := URLForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||||
|
continue
|
||||||
|
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
||||||
|
continue
|
||||||
|
} else if err == context.Canceled {
|
||||||
|
os.Remove(path)
|
||||||
|
fmt.Fprintln(progress, "\nInterrupted")
|
||||||
|
break
|
||||||
|
} else if err == context.DeadlineExceeded {
|
||||||
|
os.Remove(path)
|
||||||
|
fmt.Fprintln(progress, "Timeout downloading model")
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
os.Remove(path)
|
||||||
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
// GetOut returns the path to the output directory
|
||||||
|
func GetOut() (string, error) {
|
||||||
|
if *flagOut == "" {
|
||||||
|
return os.Getwd()
|
||||||
|
}
|
||||||
|
if info, err := os.Stat(*flagOut); err != nil {
|
||||||
|
return "", err
|
||||||
|
} else if !info.IsDir() {
|
||||||
|
return "", fmt.Errorf("not a directory: %s", info.Name())
|
||||||
|
} else {
|
||||||
|
return *flagOut, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModels returns the list of models to download
|
||||||
|
func GetModels() []string {
|
||||||
|
if flag.NArg() == 0 {
|
||||||
|
return modelNames
|
||||||
|
} else {
|
||||||
|
return flag.Args()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLForModel returns the URL for the given model on huggingface.co
|
||||||
|
func URLForModel(model string) (string, error) {
|
||||||
|
url, err := url.Parse(srcUrl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
} else {
|
||||||
|
url.Path = srcPathPrefix + "-" + model + srcExt
|
||||||
|
}
|
||||||
|
return url.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download downloads the model from the given URL to the given output directory
|
||||||
|
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
||||||
|
// Create HTTP client
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: *flagTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initiate the download
|
||||||
|
req, err := http.NewRequest("GET", model, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If output file exists and is the same size as the model, skip
|
||||||
|
path := filepath.Join(out, filepath.Base(model))
|
||||||
|
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
||||||
|
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create file
|
||||||
|
w, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
// Report
|
||||||
|
fmt.Fprintln(p, "Downloading", model, "to", out)
|
||||||
|
|
||||||
|
// Progressively download the model
|
||||||
|
data := make([]byte, bufSize)
|
||||||
|
count, pct := int64(0), int64(0)
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Cancelled, return error
|
||||||
|
return path, ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
||||||
|
default:
|
||||||
|
// Read body
|
||||||
|
n, err := resp.Body.Read(data)
|
||||||
|
if err != nil {
|
||||||
|
DownloadReport(p, pct, count, resp.ContentLength)
|
||||||
|
return path, err
|
||||||
|
} else if m, err := w.Write(data[:n]); err != nil {
|
||||||
|
return path, err
|
||||||
|
} else {
|
||||||
|
count += int64(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report periodically reports the download progress when percentage changes
|
||||||
|
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
|
||||||
|
pct_ := count * 100 / total
|
||||||
|
if pct_ > pct {
|
||||||
|
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
|
||||||
|
}
|
||||||
|
return pct_
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// TYPES
|
||||||
|
|
||||||
|
type Flags struct {
|
||||||
|
*flag.FlagSet
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// LIFECYCLE
|
||||||
|
|
||||||
|
func NewFlags(name string, args []string) (*Flags, error) {
|
||||||
|
flags := &Flags{
|
||||||
|
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the command line arguments
|
||||||
|
registerFlags(flags)
|
||||||
|
|
||||||
|
// Parse command line
|
||||||
|
if err := flags.Parse(args); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return flags, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
func (flags *Flags) GetModel() string {
|
||||||
|
return flags.Lookup("model").Value.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (flags *Flags) GetLanguage() string {
|
||||||
|
return flags.Lookup("language").Value.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (flags *Flags) IsSpeedup() bool {
|
||||||
|
return flags.Lookup("speedup").Value.String() == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (flags *Flags) IsTokens() bool {
|
||||||
|
return flags.Lookup("tokens").Value.String() == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PRIVATE METHODS
|
||||||
|
|
||||||
|
func registerFlags(flag *Flags) {
|
||||||
|
flag.String("model", "", "Path to the model file")
|
||||||
|
flag.String("language", "", "Language")
|
||||||
|
flag.Bool("speedup", false, "Enable speedup")
|
||||||
|
flag.Bool("tokens", false, "Display tokens")
|
||||||
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
// Packages
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
||||||
|
if err == flag.ErrHelp {
|
||||||
|
os.Exit(0)
|
||||||
|
} else if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
} else if flags.GetModel() == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
||||||
|
os.Exit(1)
|
||||||
|
} else if flags.NArg() == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "No input files specified")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
model, err := whisper.New(flags.GetModel())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer model.Close()
|
||||||
|
|
||||||
|
// Process files
|
||||||
|
for _, filename := range flags.Args() {
|
||||||
|
fmt.Println("Processing", filename)
|
||||||
|
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,80 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
// Package imports
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
wav "github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
|
||||||
|
var data []float32
|
||||||
|
|
||||||
|
// Create processing context
|
||||||
|
context, err := model.NewContext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open the file
|
||||||
|
fh, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fh.Close()
|
||||||
|
|
||||||
|
// Decode the WAV file
|
||||||
|
dec := wav.NewDecoder(fh)
|
||||||
|
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||||
|
return err
|
||||||
|
} else if dec.SampleRate != whisper.SampleRate {
|
||||||
|
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||||
|
} else if dec.NumChans != 1 {
|
||||||
|
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||||
|
} else {
|
||||||
|
data = buf.AsFloat32Buffer().Data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the parameters
|
||||||
|
var cb whisper.SegmentCallback
|
||||||
|
if lang != "" {
|
||||||
|
if err := context.SetLanguage(lang); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if speedup {
|
||||||
|
context.SetSpeedup(true)
|
||||||
|
}
|
||||||
|
if tokens {
|
||||||
|
cb = func(segment whisper.Segment) {
|
||||||
|
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||||
|
for _, token := range segment.Tokens {
|
||||||
|
fmt.Printf("%q ", token.Text)
|
||||||
|
}
|
||||||
|
fmt.Println("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the data
|
||||||
|
if err := context.Process(data, cb); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out the results
|
||||||
|
for {
|
||||||
|
segment, err := context.NextSegment()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
module github.com/ggerganov/whisper.cpp/bindings/go
|
||||||
|
|
||||||
|
go 1.19
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/go-audio/wav v1.1.0
|
||||||
|
github.com/stretchr/testify v1.8.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/go-audio/audio v1.0.0 // indirect
|
||||||
|
github.com/go-audio/riff v1.0.0 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
@ -0,0 +1,134 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
||||||
|
// structures, which are used by the whisper_full() function.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CGO
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <whisper.h>
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
func (p *Params) SetTranslate(v bool) {
|
||||||
|
p.translate = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetNoContext(v bool) {
|
||||||
|
p.no_context = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetSingleSegment(v bool) {
|
||||||
|
p.single_segment = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetPrintSpecial(v bool) {
|
||||||
|
p.print_special = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetPrintProgress(v bool) {
|
||||||
|
p.print_progress = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetPrintRealtime(v bool) {
|
||||||
|
p.print_realtime = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetPrintTimestamps(v bool) {
|
||||||
|
p.print_timestamps = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetSpeedup(v bool) {
|
||||||
|
p.speed_up = toBool(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetLanguage(lang int) error {
|
||||||
|
str := C.whisper_lang_str(C.int(lang))
|
||||||
|
if str == nil {
|
||||||
|
return ErrInvalidLanguage
|
||||||
|
} else {
|
||||||
|
p.language = str
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) Language() int {
|
||||||
|
if p.language == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return int(C.whisper_lang_id(p.language))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetThreads(threads int) {
|
||||||
|
p.n_threads = C.int(threads)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetOffset(offset_ms int) {
|
||||||
|
p.offset_ms = C.int(offset_ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Params) SetDuration(duration_ms int) {
|
||||||
|
p.duration_ms = C.int(duration_ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PRIVATE METHODS
|
||||||
|
|
||||||
|
func toBool(v bool) C.bool {
|
||||||
|
if v {
|
||||||
|
return C.bool(true)
|
||||||
|
}
|
||||||
|
return C.bool(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// STRINGIFY
|
||||||
|
|
||||||
|
func (p *Params) String() string {
|
||||||
|
str := "<whisper.params"
|
||||||
|
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
||||||
|
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
||||||
|
if p.language != nil {
|
||||||
|
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
||||||
|
}
|
||||||
|
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||||
|
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||||
|
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||||
|
if p.translate {
|
||||||
|
str += " translate"
|
||||||
|
}
|
||||||
|
if p.no_context {
|
||||||
|
str += " no_context"
|
||||||
|
}
|
||||||
|
if p.single_segment {
|
||||||
|
str += " single_segment"
|
||||||
|
}
|
||||||
|
if p.print_special {
|
||||||
|
str += " print_special"
|
||||||
|
}
|
||||||
|
if p.print_progress {
|
||||||
|
str += " print_progress"
|
||||||
|
}
|
||||||
|
if p.print_realtime {
|
||||||
|
str += " print_realtime"
|
||||||
|
}
|
||||||
|
if p.print_timestamps {
|
||||||
|
str += " print_timestamps"
|
||||||
|
}
|
||||||
|
if p.token_timestamps {
|
||||||
|
str += " token_timestamps"
|
||||||
|
}
|
||||||
|
if p.speed_up {
|
||||||
|
str += " speed_up"
|
||||||
|
}
|
||||||
|
|
||||||
|
return str + ">"
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// ERRORS
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||||
|
ErrInternalAppError = errors.New("internal application error")
|
||||||
|
ErrProcessingFailed = errors.New("processing failed")
|
||||||
|
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CONSTANTS
|
||||||
|
|
||||||
|
// SampleRate is the sample rate of the audio data.
|
||||||
|
const SampleRate = whisper.SampleRate
|
||||||
|
|
||||||
|
// SampleBits is the number of bytes per sample.
|
||||||
|
const SampleBits = whisper.SampleBits
|
@ -0,0 +1,145 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// TYPES
|
||||||
|
|
||||||
|
type context struct {
|
||||||
|
n int
|
||||||
|
model *model
|
||||||
|
params whisper.Params
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure context adheres to the interface
|
||||||
|
var _ Context = (*context)(nil)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// LIFECYCLE
|
||||||
|
|
||||||
|
func NewContext(model *model, params whisper.Params) (Context, error) {
|
||||||
|
context := new(context)
|
||||||
|
context.model = model
|
||||||
|
context.params = params
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return context, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
// Set the language to use for speech recognition.
|
||||||
|
func (context *context) SetLanguage(lang string) error {
|
||||||
|
if context.model.ctx == nil {
|
||||||
|
return ErrInternalAppError
|
||||||
|
}
|
||||||
|
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||||
|
return ErrUnsupportedLanguage
|
||||||
|
} else if err := context.params.SetLanguage(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Return success
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get language
|
||||||
|
func (context *context) Language() string {
|
||||||
|
return whisper.Whisper_lang_str(context.params.Language())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set speedup flag
|
||||||
|
func (context *context) SetSpeedup(v bool) {
|
||||||
|
context.params.SetSpeedup(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process new sample data and return any errors
|
||||||
|
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||||
|
if context.model.ctx == nil {
|
||||||
|
return ErrInternalAppError
|
||||||
|
}
|
||||||
|
// If the callback is defined then we force on single_segment mode
|
||||||
|
if cb != nil {
|
||||||
|
context.params.SetSingleSegment(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't do parallel processing at the moment
|
||||||
|
processors := 0
|
||||||
|
if processors > 1 {
|
||||||
|
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||||
|
if cb != nil {
|
||||||
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
|
s0 := num_segments - new
|
||||||
|
for i := s0; i < num_segments; i++ {
|
||||||
|
cb(toSegment(context.model.ctx, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||||
|
if cb != nil {
|
||||||
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
|
s0 := num_segments - new
|
||||||
|
for i := s0; i < num_segments; i++ {
|
||||||
|
cb(toSegment(context.model.ctx, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the next segment of tokens
|
||||||
|
func (context *context) NextSegment() (Segment, error) {
|
||||||
|
if context.model.ctx == nil {
|
||||||
|
return Segment{}, ErrInternalAppError
|
||||||
|
}
|
||||||
|
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||||
|
return Segment{}, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate result
|
||||||
|
result := toSegment(context.model.ctx, context.n)
|
||||||
|
|
||||||
|
// Increment the cursor
|
||||||
|
context.n++
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PRIVATE METHODS
|
||||||
|
|
||||||
|
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||||
|
return Segment{
|
||||||
|
Num: n,
|
||||||
|
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||||
|
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||||
|
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||||
|
Tokens: toTokens(ctx, n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||||
|
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||||
|
for i := 0; i < len(result); i++ {
|
||||||
|
result[i] = Token{
|
||||||
|
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||||
|
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
||||||
|
P: ctx.Whisper_full_get_token_p(n, i),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
package whisper_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
// Packages
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
assert "github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelPath = "../../models/ggml-tiny.bin"
|
||||||
|
SamplePath = "../../samples/jfk.wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Whisper_000(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, model not found:", ModelPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
model, err := whisper.New(ModelPath)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.NotNil(model)
|
||||||
|
assert.NoError(model.Close())
|
||||||
|
|
||||||
|
t.Log("languages=", model.Languages())
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Whisper_001(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, model not found:", ModelPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
model, err := whisper.New(ModelPath)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.NotNil(model)
|
||||||
|
defer model.Close()
|
||||||
|
|
||||||
|
// Get context for decoding
|
||||||
|
ctx, err := model.NewContext()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.NotNil(ctx)
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
/*
|
||||||
|
This is the higher-level speech-to-text whisper.cpp API for go
|
||||||
|
*/
|
||||||
|
package whisper
|
@ -0,0 +1,63 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// TYPES
|
||||||
|
|
||||||
|
// SegmentCallback is the callback function for processing segments in real
|
||||||
|
// time. It is called during the Process function
|
||||||
|
type SegmentCallback func(Segment)
|
||||||
|
|
||||||
|
// Model is the interface to a whisper model. Create a new model with the
|
||||||
|
// function whisper.New(string)
|
||||||
|
type Model interface {
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// Return a new speech-to-text context.
|
||||||
|
NewContext() (Context, error)
|
||||||
|
|
||||||
|
// Return all languages supported.
|
||||||
|
Languages() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context is the speach recognition context.
|
||||||
|
type Context interface {
|
||||||
|
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||||
|
Language() string // Get language
|
||||||
|
SetSpeedup(bool) // Set speedup flag
|
||||||
|
|
||||||
|
// Process mono audio data and return any errors.
|
||||||
|
// If defined, newly generated segments are passed to the
|
||||||
|
// callback function during processing.
|
||||||
|
Process([]float32, SegmentCallback) error
|
||||||
|
|
||||||
|
// After process is called, return segments until the end of the stream
|
||||||
|
// is reached, when io.EOF is returned.
|
||||||
|
NextSegment() (Segment, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Segment is the text result of a speech recognition.
|
||||||
|
type Segment struct {
|
||||||
|
// Segment Number
|
||||||
|
Num int
|
||||||
|
|
||||||
|
// Time beginning and end timestamps for the segment.
|
||||||
|
Start, End time.Duration
|
||||||
|
|
||||||
|
// The text of the segment.
|
||||||
|
Text string
|
||||||
|
|
||||||
|
// The tokens of the segment.
|
||||||
|
Tokens []Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is a text or special token
|
||||||
|
type Token struct {
|
||||||
|
Id int
|
||||||
|
Text string
|
||||||
|
P float32
|
||||||
|
}
|
@ -0,0 +1,95 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// TYPES
|
||||||
|
|
||||||
|
type model struct {
|
||||||
|
path string
|
||||||
|
ctx *whisper.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure model adheres to the interface
|
||||||
|
var _ Model = (*model)(nil)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// LIFECYCLE
|
||||||
|
|
||||||
|
func New(path string) (*model, error) {
|
||||||
|
model := new(model)
|
||||||
|
if _, err := os.Stat(path); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||||
|
return nil, ErrUnableToLoadModel
|
||||||
|
} else {
|
||||||
|
model.ctx = ctx
|
||||||
|
model.path = path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (model *model) Close() error {
|
||||||
|
if model.ctx != nil {
|
||||||
|
model.ctx.Whisper_free()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release resources
|
||||||
|
model.ctx = nil
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// STRINGIFY
|
||||||
|
|
||||||
|
func (model *model) String() string {
|
||||||
|
str := "<whisper.model"
|
||||||
|
if model.ctx != nil {
|
||||||
|
str += fmt.Sprintf(" model=%q", model.path)
|
||||||
|
}
|
||||||
|
return str + ">"
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
// Return all recognized languages. Initially it is set to auto-detect
|
||||||
|
func (model *model) Languages() []string {
|
||||||
|
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||||
|
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||||
|
str := whisper.Whisper_lang_str(i)
|
||||||
|
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||||
|
result = append(result, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (model *model) NewContext() (Context, error) {
|
||||||
|
if model.ctx == nil {
|
||||||
|
return nil, ErrInternalAppError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new context
|
||||||
|
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||||
|
params.SetTranslate(false)
|
||||||
|
params.SetPrintSpecial(false)
|
||||||
|
params.SetPrintProgress(false)
|
||||||
|
params.SetPrintRealtime(false)
|
||||||
|
params.SetPrintTimestamps(false)
|
||||||
|
params.SetThreads(runtime.NumCPU())
|
||||||
|
|
||||||
|
// Return new context
|
||||||
|
return NewContext(model, params)
|
||||||
|
}
|
Binary file not shown.
@ -0,0 +1,412 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CGO
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -I${SRCDIR}/../..
|
||||||
|
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
||||||
|
#cgo darwin LDFLAGS: -framework Accelerate
|
||||||
|
#include <whisper.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
extern void callNewSegment(void* user_data, int new);
|
||||||
|
extern bool callEncoderBegin(void* user_data);
|
||||||
|
|
||||||
|
// Text segment callback
|
||||||
|
// Called on every newly generated text segment
|
||||||
|
// Use the whisper_full_...() functions to obtain the text segments
|
||||||
|
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
||||||
|
if(user_data != NULL && ctx != NULL) {
|
||||||
|
callNewSegment(user_data, n_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoder begin callback
|
||||||
|
// If not NULL, called before the encoder starts
|
||||||
|
// If it returns false, the computation is aborted
|
||||||
|
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
||||||
|
if(user_data != NULL && ctx != NULL) {
|
||||||
|
return callEncoderBegin(user_data);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get default parameters and set callbacks
|
||||||
|
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
||||||
|
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||||
|
params.new_segment_callback = whisper_new_segment_cb;
|
||||||
|
params.new_segment_callback_user_data = (void*)(ctx);
|
||||||
|
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||||
|
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// TYPES
|
||||||
|
|
||||||
|
type (
|
||||||
|
Context C.struct_whisper_context
|
||||||
|
Token C.whisper_token
|
||||||
|
TokenData C.struct_whisper_token_data
|
||||||
|
SamplingStrategy C.enum_whisper_sampling_strategy
|
||||||
|
Params C.struct_whisper_full_params
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GLOBALS
|
||||||
|
|
||||||
|
const (
|
||||||
|
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
||||||
|
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
||||||
|
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
||||||
|
NumFFT = C.WHISPER_N_FFT
|
||||||
|
NumMEL = C.WHISPER_N_MEL
|
||||||
|
HopLength = C.WHISPER_HOP_LENGTH
|
||||||
|
ChunkSize = C.WHISPER_CHUNK_SIZE
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
||||||
|
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
||||||
|
ErrConversionFailed = errors.New("whisper_convert failed")
|
||||||
|
ErrInvalidLanguage = errors.New("invalid language")
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// PUBLIC METHODS
|
||||||
|
|
||||||
|
// Allocates all memory needed for the model and loads the model from the given file.
|
||||||
|
// Returns NULL on failure.
|
||||||
|
func Whisper_init(path string) *Context {
|
||||||
|
cPath := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
if ctx := C.whisper_init(cPath); ctx != nil {
|
||||||
|
return (*Context)(ctx)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frees all memory allocated by the model.
|
||||||
|
func (ctx *Context) Whisper_free() {
|
||||||
|
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert RAW PCM audio to log mel spectrogram.
|
||||||
|
// The resulting spectrogram is stored inside the provided whisper context.
|
||||||
|
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||||
|
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||||
|
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||||
|
// n_mel must be 80
|
||||||
|
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
||||||
|
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||||
|
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||||
|
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||||
|
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
||||||
|
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||||
|
// Make sure to call whisper_encode() first.
|
||||||
|
// tokens + n_tokens is the provided context for the decoder.
|
||||||
|
// n_past is the number of tokens to use from previous decoder calls.
|
||||||
|
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
||||||
|
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// whisper_sample_best() returns the token with the highest probability
|
||||||
|
func (ctx *Context) Whisper_sample_best() TokenData {
|
||||||
|
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// whisper_sample_timestamp() returns the most probable timestamp token
|
||||||
|
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
|
||||||
|
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
|
// Returns the number of tokens on success
|
||||||
|
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||||
|
cText := C.CString(text)
|
||||||
|
defer C.free(unsafe.Pointer(cText))
|
||||||
|
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
||||||
|
return int(n), nil
|
||||||
|
} else {
|
||||||
|
return 0, ErrTokenizerFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the id of the specified language, returns -1 if not found
|
||||||
|
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||||
|
return int(C.whisper_lang_id(C.CString(lang)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Largest language id (i.e. number of available languages - 1)
|
||||||
|
func Whisper_lang_max_id() int {
|
||||||
|
return int(C.whisper_lang_max_id())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
||||||
|
// returns empty string if not found
|
||||||
|
func Whisper_lang_str(id int) string {
|
||||||
|
return C.GoString(C.whisper_lang_str(C.int(id)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||||
|
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||||
|
// Returns the probabilities of all languages.
|
||||||
|
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||||
|
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
||||||
|
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||||
|
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||||
|
return nil, ErrAutoDetectFailed
|
||||||
|
} else {
|
||||||
|
return probs, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) Whisper_n_len() int {
|
||||||
|
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) Whisper_n_vocab() int {
|
||||||
|
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) Whisper_n_text_ctx() int {
|
||||||
|
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) Whisper_is_multilingual() int {
|
||||||
|
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The probabilities for the next token
|
||||||
|
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
||||||
|
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Token Id -> String. Uses the vocabulary in the provided context
|
||||||
|
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
||||||
|
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_eot() Token {
|
||||||
|
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_sot() Token {
|
||||||
|
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_prev() Token {
|
||||||
|
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_solm() Token {
|
||||||
|
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_not() Token {
|
||||||
|
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_beg() Token {
|
||||||
|
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens
|
||||||
|
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
||||||
|
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task tokens
|
||||||
|
func Whisper_token_translate() Token {
|
||||||
|
return Token(C.whisper_token_translate())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task tokens
|
||||||
|
func Whisper_token_transcribe() Token {
|
||||||
|
return Token(C.whisper_token_transcribe())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance information
|
||||||
|
func (ctx *Context) Whisper_print_timings() {
|
||||||
|
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance information
|
||||||
|
func (ctx *Context) Whisper_reset_timings() {
|
||||||
|
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print system information
|
||||||
|
func Whisper_print_system_info() string {
|
||||||
|
return C.GoString(C.whisper_print_system_info())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return default parameters for a strategy
|
||||||
|
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
||||||
|
// Get default parameters
|
||||||
|
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
// Uses the specified decoding strategy to obtain the text.
|
||||||
|
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||||
|
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||||
|
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||||
|
defer registerEncoderBeginCallback(ctx, nil)
|
||||||
|
defer registerNewSegmentCallback(ctx, nil)
|
||||||
|
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||||
|
// It seems this approach can offer some speedup in some cases.
|
||||||
|
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||||
|
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||||
|
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||||
|
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||||
|
defer registerEncoderBeginCallback(ctx, nil)
|
||||||
|
defer registerNewSegmentCallback(ctx, nil)
|
||||||
|
|
||||||
|
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrConversionFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of generated text segments.
|
||||||
|
// A segment can be a few words, a sentence, or even a paragraph.
|
||||||
|
func (ctx *Context) Whisper_full_n_segments() int {
|
||||||
|
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the start and end time of the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||||
|
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the start and end time of the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||||
|
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the text of the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||||
|
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get number of tokens in the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||||
|
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the token text of the specified token index in the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||||
|
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the token of the specified token index in the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||||
|
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get token data for the specified token in the specified segment.
|
||||||
|
// This contains probabilities, timestamps, etc.
|
||||||
|
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
||||||
|
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the probability of the specified token in the specified segment.
|
||||||
|
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||||
|
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CALLBACKS
|
||||||
|
|
||||||
|
var (
|
||||||
|
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||||
|
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||||
|
if fn == nil {
|
||||||
|
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||||
|
} else {
|
||||||
|
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||||
|
if fn == nil {
|
||||||
|
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||||
|
} else {
|
||||||
|
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export callNewSegment
|
||||||
|
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||||
|
if fn, ok := cbNewSegment[user_data]; ok {
|
||||||
|
fn(int(new))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export callEncoderBegin
|
||||||
|
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||||
|
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||||
|
if fn() {
|
||||||
|
return C.bool(true)
|
||||||
|
} else {
|
||||||
|
return C.bool(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
@ -0,0 +1,110 @@
|
|||||||
|
package whisper_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
// Packages
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||||
|
wav "github.com/go-audio/wav"
|
||||||
|
assert "github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelPath = "models/ggml-small.en.bin"
|
||||||
|
SamplePath = "samples/jfk.wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Whisper_000(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, model not found:", ModelPath)
|
||||||
|
}
|
||||||
|
ctx := whisper.Whisper_init(ModelPath)
|
||||||
|
assert.NotNil(ctx)
|
||||||
|
ctx.Whisper_free()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Whisper_001(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, model not found:", ModelPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open samples
|
||||||
|
fh, err := os.Open(SamplePath)
|
||||||
|
assert.NoError(err)
|
||||||
|
defer fh.Close()
|
||||||
|
|
||||||
|
// Read samples
|
||||||
|
d := wav.NewDecoder(fh)
|
||||||
|
buf, err := d.FullPCMBuffer()
|
||||||
|
assert.NoError(err)
|
||||||
|
|
||||||
|
// Run whisper
|
||||||
|
ctx := whisper.Whisper_init(ModelPath)
|
||||||
|
assert.NotNil(ctx)
|
||||||
|
defer ctx.Whisper_free()
|
||||||
|
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
||||||
|
|
||||||
|
// Print out tokens
|
||||||
|
num_segments := ctx.Whisper_full_n_segments()
|
||||||
|
assert.GreaterOrEqual(num_segments, 1)
|
||||||
|
for i := 0; i < num_segments; i++ {
|
||||||
|
str := ctx.Whisper_full_get_segment_text(i)
|
||||||
|
assert.NotEmpty(str)
|
||||||
|
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
||||||
|
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
||||||
|
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Whisper_002(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||||
|
str := whisper.Whisper_lang_str(i)
|
||||||
|
assert.NotEmpty(str)
|
||||||
|
t.Log(str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Whisper_003(t *testing.T) {
|
||||||
|
threads := runtime.NumCPU()
|
||||||
|
assert := assert.New(t)
|
||||||
|
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, model not found:", ModelPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||||
|
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open samples
|
||||||
|
fh, err := os.Open(SamplePath)
|
||||||
|
assert.NoError(err)
|
||||||
|
defer fh.Close()
|
||||||
|
|
||||||
|
// Read samples
|
||||||
|
d := wav.NewDecoder(fh)
|
||||||
|
buf, err := d.FullPCMBuffer()
|
||||||
|
assert.NoError(err)
|
||||||
|
|
||||||
|
// Make the model
|
||||||
|
ctx := whisper.Whisper_init(ModelPath)
|
||||||
|
assert.NotNil(ctx)
|
||||||
|
defer ctx.Whisper_free()
|
||||||
|
|
||||||
|
// Get MEL
|
||||||
|
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
||||||
|
|
||||||
|
// Get Languages
|
||||||
|
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
||||||
|
assert.NoError(err)
|
||||||
|
for i, p := range languages {
|
||||||
|
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue