Initial import of golang bindings

pull/288/head
David Thorpe 3 years ago
parent 1d716d6e34
commit 7007828bed

@ -0,0 +1,3 @@
build
models
go.sum

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 David Thorpe
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,37 @@
CMAKE := $(shell which cmake)
BUILD_DIR := "build"
MODELS_DIR := "models"
EXAMPLES_DIR := $(wildcard examples/*)
C_INCLUDE_PATH := "../.."
all: clean whisper examples
whisper: mkdir
@echo Build whisper
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
@${CMAKE} --build ${BUILD_DIR} --target whisper
test: model-small whisper
@go mod tidy
@go test -v .
examples: $(EXAMPLES_DIR)
model-small: examples/go-model-download
@${BUILD_DIR}/go-model-download -out models small.en
$(EXAMPLES_DIR): clean mkdir
@echo Build example $(notdir $@)
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir:
@echo Mkdir ${BUILD_DIR}
@install -d ${BUILD_DIR}
@echo Mkdir ${MODELS_DIR}
@install -d ${MODELS_DIR}
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go mod tidy
@go clean

@ -0,0 +1,74 @@
# Go bindings for Whisper
This package provides Go bindings for whisper.cpp. They have been tested on:
* Darwin (OS X) 12.6 on x64_64
* Debian Linux on arm64
The "low level" bindings are in the `bindings/go` directory and there is a more
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
is as follows:
```go
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
var modelpath string // Path to the model
var samples []float32 // Samples to process
// Load the model
model, err := whisper.New(modelpath)
if err != nil {
panic(err)
}
defer model.Close()
// Process samples
context, err := model.NewContext()
if err != nil {
panic(err)
}
if err := context.Process(samples); err != nil {
return err
}
// Print out the results
for {
segment, err := context.NextSegment()
if err != nil {
break
}
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
}
}
```
## Building & Testing
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
```bash
make test
```
Then build the examples with:
```bash
make examples
```
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
```bash
./build/go-model-download -out models
```
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```

@ -0,0 +1,5 @@
/*
github.com/ggerganov/whisper.cpp/bindings/go
provides a speech-to-text service bindings for the Go programming language.
*/
package whisper

@ -0,0 +1,30 @@
package main
import (
"context"
"os"
"os/signal"
)
// ContextForSignal returns a context object which is cancelled when a signal
// is received. It returns nil if no signal parameter is provided
func ContextForSignal(signals ...os.Signal) context.Context {
if len(signals) == 0 {
return nil
}
ch := make(chan os.Signal)
ctx, cancel := context.WithCancel(context.Background())
// Send message on channel when signal received
signal.Notify(ch, signals...)
// When any signal received, call cancel
go func() {
<-ch
cancel()
}()
// Return success
return ctx
}

@ -0,0 +1,206 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"syscall"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
srcUrl = "https://huggingface.co/" // The location of the models
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
)
var (
// The output folder. When not set, use current working directory.
flagOut = flag.String("out", "", "Output folder")
// HTTP timeout parameter - will timeout if takes longer than this to download a model
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
// Quiet parameter - will not print progress if set
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
)
///////////////////////////////////////////////////////////////////////////////
// MAIN
func main() {
flag.Usage = func() {
name := filepath.Base(flag.CommandLine.Name())
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
flag.PrintDefaults()
}
flag.Parse()
// Get output path
out, err := GetOut()
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
// Create context which quits on SIGINT or SIGQUIT
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
// Progress filehandle
progress := os.Stdout
if *flagQuiet {
progress, err = os.Open(os.DevNull)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
defer progress.Close()
}
// Download models - exit on error or interrupt
for _, model := range GetModels() {
url, err := URLForModel(model)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
continue
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
continue
} else if err == context.Canceled {
os.Remove(path)
fmt.Fprintln(progress, "\nInterrupted")
break
} else if err == context.DeadlineExceeded {
os.Remove(path)
fmt.Fprintln(progress, "Timeout downloading model")
continue
} else {
os.Remove(path)
fmt.Fprintln(os.Stderr, "Error:", err)
break
}
}
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// GetOut returns the path to the output directory
func GetOut() (string, error) {
if *flagOut == "" {
return os.Getwd()
}
if info, err := os.Stat(*flagOut); err != nil {
return "", err
} else if !info.IsDir() {
return "", fmt.Errorf("not a directory: %s", info.Name())
} else {
return *flagOut, nil
}
}
// GetModels returns the list of models to download
func GetModels() []string {
if flag.NArg() == 0 {
return modelNames
} else {
return flag.Args()
}
}
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
} else {
url.Path = srcPathPrefix + "-" + model + srcExt
}
return url.String(), nil
}
// Download downloads the model from the given URL to the given output directory
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
// Create HTTP client
client := http.Client{
Timeout: *flagTimeout,
}
// Initiate the download
req, err := http.NewRequest("GET", model, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s: %s", model, resp.Status)
}
// If output file exists and is the same size as the model, skip
path := filepath.Join(out, filepath.Base(model))
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
fmt.Fprintln(p, "Skipping", model, "as it already exists")
return "", nil
}
// Create file
w, err := os.Create(path)
if err != nil {
return "", err
}
defer w.Close()
// Report
fmt.Fprintln(p, "Downloading", model, "to", out)
// Progressively download the model
data := make([]byte, bufSize)
count, pct := int64(0), int64(0)
ticker := time.NewTicker(5 * time.Second)
for {
select {
case <-ctx.Done():
// Cancelled, return error
return path, ctx.Err()
case <-ticker.C:
pct = DownloadReport(p, pct, count, resp.ContentLength)
default:
// Read body
n, err := resp.Body.Read(data)
if err != nil {
DownloadReport(p, pct, count, resp.ContentLength)
return path, err
} else if m, err := w.Write(data[:n]); err != nil {
return path, err
} else {
count += int64(m)
}
}
}
}
// Report periodically reports the download progress when percentage changes
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
pct_ := count * 100 / total
if pct_ > pct {
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
}
return pct_
}

@ -0,0 +1,51 @@
package main
import (
"flag"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type Flags struct {
*flag.FlagSet
}
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewFlags(name string, args []string) (*Flags, error) {
flags := &Flags{
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
}
// Register the command line arguments
registerFlags(flags)
// Parse command line
if err := flags.Parse(args); err != nil {
return nil, err
}
// Return success
return flags, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (flags *Flags) GetModel() string {
return flags.Lookup("model").Value.String()
}
func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Language")
}

@ -0,0 +1,44 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
if err == flag.ErrHelp {
os.Exit(0)
} else if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
} else if flags.GetModel() == "" {
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
os.Exit(1)
} else if flags.NArg() == 0 {
fmt.Fprintln(os.Stderr, "No input files specified")
os.Exit(1)
}
// Load model
model, err := whisper.New(flags.GetModel())
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
defer model.Close()
// Process files
for _, filename := range flags.Args() {
fmt.Println("Processing", filename)
if err := Process(model, filename); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
}
}

@ -0,0 +1,60 @@
package main
import (
"fmt"
"io"
"os"
"time"
// Package imports
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
func Process(model whisper.Model, path string) error {
var data []float32
// Create processing context
context, err := model.NewContext()
if err != nil {
return err
}
// Open the file
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()
// Decode the WAV file
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
} else if dec.SampleRate != whisper.SampleRate {
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
} else if dec.NumChans != 1 {
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
} else {
data = buf.AsFloat32Buffer().Data
}
// Process the data
if err := context.Process(data); err != nil {
return err
}
// Print out the results
for {
segment, err := context.NextSegment()
if err == io.EOF {
break
} else if err != nil {
return err
}
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
}
// Return success
return nil
}

@ -0,0 +1,16 @@
module github.com/ggerganov/whisper.cpp/bindings/go
go 1.19
require (
github.com/go-audio/wav v1.1.0
github.com/stretchr/testify v1.8.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

@ -0,0 +1,105 @@
package whisper
// This file defines the whisper_token, whisper_token_data and whisper_full_params
// structures, which are used by the whisper_full() function.
import (
"fmt"
)
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#include <stdbool.h>
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (p *Params) SetTranslate(v bool) {
p.translate = toBool(v)
}
func (p *Params) SetNoContext(v bool) {
p.no_context = toBool(v)
}
func (p *Params) SetSingleSegment(v bool) {
p.single_segment = toBool(v)
}
func (p *Params) SetPrintSpecial(v bool) {
p.print_special = toBool(v)
}
func (p *Params) SetPrintProgress(v bool) {
p.print_progress = toBool(v)
}
func (p *Params) SetPrintRealtime(v bool) {
p.print_realtime = toBool(v)
}
func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}
func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func toBool(v bool) C.bool {
if v {
return C.bool(true)
}
return C.bool(false)
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (p *Params) String() string {
str := "<whisper.params"
str += fmt.Sprintf(" strategy=%v", p.strategy)
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
if p.language != nil {
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
}
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
if p.translate {
str += " translate"
}
if p.no_context {
str += " no_context"
}
if p.single_segment {
str += " single_segment"
}
if p.print_special {
str += " print_special"
}
if p.print_progress {
str += " print_progress"
}
if p.print_realtime {
str += " print_realtime"
}
if p.print_timestamps {
str += " print_timestamps"
}
if p.token_timestamps {
str += " token_timestamps"
}
if p.speed_up {
str += " speed_up"
}
return str + ">"
}

@ -0,0 +1,26 @@
package whisper
import (
"errors"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
// SampleRate is the sample rate of the audio data.
const SampleRate = whisper.SampleRate
// SampleBits is the number of bytes per sample.
const SampleBits = whisper.SampleBits

@ -0,0 +1,80 @@
package whisper
import (
// Bindings
"io"
"strings"
"time"
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type context struct {
n int
model *model
params whisper.Params
}
// Make sure context adheres to the interface
var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
// Return success
return context, nil
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (context *context) String() string {
str := "<whisper.context"
return str + ">"
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Process new sample data and return any errors
func (context *context) Process(data []float32) error {
// Process data
if context.model.ctx == nil {
return ErrInternalAppError
} else if ret := context.model.ctx.Whisper_full(context.params, data, nil, nil); ret != 0 {
return ErrProcessingFailed
}
// Return success
return nil
}
// Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) {
result := Segment{}
if context.model.ctx == nil {
return result, ErrInternalAppError
}
if context.n >= context.model.ctx.Whisper_full_n_segments() {
return result, io.EOF
}
// Populate result
result.Text = strings.TrimSpace(context.model.ctx.Whisper_full_get_segment_text(context.n))
result.Start = time.Duration(context.model.ctx.Whisper_full_get_segment_t0(context.n)) * time.Millisecond
result.End = time.Duration(context.model.ctx.Whisper_full_get_segment_t1(context.n)) * time.Millisecond
// Increment the cursor
context.n++
// Return success
return result, nil
}

@ -0,0 +1,4 @@
/*
This is the higher-level speech-to-text whisper.cpp API for go
*/
package whisper

@ -0,0 +1,35 @@
package whisper
import (
"io"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
io.Closer
NewContext() (Context, error)
}
// Context is the speach recognition context.
type Context interface {
// Process mono audio data and return any errors.
Process([]float32) error
// Return segments until the end of the stream is reached, when io.EOF is
// returned.
NextSegment() (Segment, error)
}
type Segment struct {
// Time beginning and end timestamps for the segment.
Start, End time.Duration
// The text of the segment.
Text string
}

@ -0,0 +1,82 @@
package whisper
import (
"fmt"
"os"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type model struct {
path string
ctx *whisper.Context
}
// Make sure model adheres to the interface
var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (*model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
model.path = path
}
// Return success
return model, nil
}
func (model *model) Close() error {
if model.ctx != nil {
model.ctx.Whisper_free()
}
// Release resources
model.ctx = nil
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (model *model) String() string {
str := "<whisper.model"
if model.ctx != nil {
str += fmt.Sprintf(" model=%q", model.path)
}
return str + ">"
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (model *model) NewContext() (Context, error) {
if model.ctx == nil {
return nil, ErrInternalAppError
}
// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
params.SetPrintRealtime(false)
params.SetPrintTimestamps(false)
params.SetSpeedup(false)
// Return new context
return NewContext(model, params)
}

Binary file not shown.

@ -0,0 +1,342 @@
package whisper
import "unsafe"
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#cgo CFLAGS: -I${SRCDIR}/../..
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h>
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callNewSegment(user_data, n_new);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) {
return callEncoderBegin(user_data);
}
return false;
}
// Get default parameters and set callbacks
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
struct whisper_full_params params = whisper_full_default_params(strategy);
params.new_segment_callback = whisper_new_segment_cb;
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
return params;
}
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// TYPES
type (
Context C.struct_whisper_context
Token C.whisper_token
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
)
///////////////////////////////////////////////////////////////////////////////
// GLOBALS
const (
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
)
const (
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
NumFFT = C.WHISPER_N_FFT
NumMEL = C.WHISPER_N_MEL
HopLength = C.WHISPER_HOP_LENGTH
ChunkSize = C.WHISPER_CHUNK_SIZE
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil
}
}
// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))
}
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) int {
return int(C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)))
}
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) int {
return int(C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)))
}
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
func (ctx *Context) Whisper_encode(offset, threads int) int {
return int(C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)))
}
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) int {
return int(C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)))
}
// whisper_sample_best() returns the token with the highest probability
func (ctx *Context) Whisper_sample_best() TokenData {
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
}
// whisper_sample_timestamp() returns the most probable timestamp token
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
}
// Return the id of the specified language, returns -1 if not found
func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_vocab() int {
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
}
// The probabilities for the next token
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
//}
// Token Id -> String. Uses the vocabulary in the provided context
func (ctx *Context) Whisper_token_to_str(token Token) string {
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
}
// Special tokens
func (ctx *Context) Whisper_token_eot() Token {
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_sot() Token {
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_prev() Token {
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_solm() Token {
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_not() Token {
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_beg() Token {
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
}
// Task tokens
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
}
// Task tokens
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
}
// Performance information
func (ctx *Context) Whisper_print_timings() {
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
}
// Performance information
func (ctx *Context) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
}
// Print system information
func Whisper_print_system_info() string {
return C.GoString(C.whisper_print_system_info())
}
// Return default parameters for a strategy
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
// Get default parameters
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
}
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) int {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
result := int(C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))))
registerEncoderBeginCallback(ctx, nil)
registerNewSegmentCallback(ctx, nil)
return result
}
// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) int {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
result := int(C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)))
registerEncoderBeginCallback(ctx, nil)
registerNewSegmentCallback(ctx, nil)
return result
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
///////////////////////////////////////////////////////////////////////////////
// CALLBACKS
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbNewSegment, unsafe.Pointer(ctx))
} else {
cbNewSegment[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
} else {
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
}
}
//export callNewSegment
func callNewSegment(user_data unsafe.Pointer, new C.int) {
if fn, ok := cbNewSegment[user_data]; ok {
fn(int(new))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
if fn() {
return C.bool(true)
} else {
return C.bool(false)
}
}
return true
}

@ -0,0 +1,66 @@
package whisper_test
import (
"os"
"testing"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
"github.com/go-audio/wav"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "models/ggml-small.en.bin"
SamplePath = "samples/jfk.wav"
)
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
ctx.Whisper_free()
}
func Test_Whisper_001(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
assert.NoError(err)
buf2 := buf.AsFloat32Buffer()
// Run whisper
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
ret := ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf2.Data, nil, nil)
assert.Equal(0, ret)
// Print out tokens
num_segments := ctx.Whisper_full_n_segments()
assert.GreaterOrEqual(num_segments, 1)
for i := 0; i < num_segments; i++ {
str := ctx.Whisper_full_get_segment_text(i)
assert.NotEmpty(str)
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
t.Logf("[%6s->%-6s] %q", t0, t1, str)
}
}
Loading…
Cancel
Save