Merge pull request #2 from djthorpe/djthorpe-20221218

Update to latest API
pull/287/head
David Thorpe 3 years ago committed by GitHub
commit b5d9c7aa43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,7 @@ whisper: mkdir
test: model-small whisper test: model-small whisper
@go mod tidy @go mod tidy
@go test -v . @go test -v .
@go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR) examples: $(EXAMPLES_DIR)

@ -4,6 +4,7 @@ This package provides Go bindings for whisper.cpp. They have been tested on:
* Darwin (OS X) 12.6 on x86_64 * Darwin (OS X) 12.6 on x86_64
* Debian Linux on arm64 * Debian Linux on arm64
* Fedora Linux on x86_64
The "low level" bindings are in the `bindings/go` directory and there is a more The "low level" bindings are in the `bindings/go` directory and there is a more
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
@ -30,7 +31,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := context.Process(samples); err != nil { if err := context.Process(samples, nil); err != nil {
return err return err
} }
@ -55,7 +56,7 @@ cd whisper.cpp/bindings/go
make test make test
``` ```
Then build the examples with: This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
```bash ```bash
make examples make examples

@ -42,10 +42,20 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String() return flags.Lookup("language").Value.String()
} }
// IsSpeedup reports whether the "speedup" flag currently holds the value
// "true".
func (flags *Flags) IsSpeedup() bool {
	speedup := flags.Lookup("speedup")
	return speedup.Value.String() == "true"
}
// IsTokens reports whether the "tokens" flag currently holds the value
// "true".
func (flags *Flags) IsTokens() bool {
	tokens := flags.Lookup("tokens")
	return tokens.Value.String() == "true"
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS
func registerFlags(flag *Flags) { func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file") flag.String("model", "", "Path to the model file")
flag.String("language", "", "Language") flag.String("language", "", "Language")
flag.Bool("speedup", false, "Enable speedup")
flag.Bool("tokens", false, "Display tokens")
} }

@ -36,7 +36,7 @@ func main() {
// Process files // Process files
for _, filename := range flags.Args() { for _, filename := range flags.Args() {
fmt.Println("Processing", filename) fmt.Println("Processing", filename)
if err := Process(model, filename); err != nil { if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
continue continue
} }

@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav" wav "github.com/go-audio/wav"
) )
func Process(model whisper.Model, path string) error { func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
var data []float32 var data []float32
// Create processing context // Create processing context
@ -39,8 +39,28 @@ func Process(model whisper.Model, path string) error {
data = buf.AsFloat32Buffer().Data data = buf.AsFloat32Buffer().Data
} }
// Set the parameters
var cb whisper.SegmentCallback
if lang != "" {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
cb = func(segment whisper.Segment) {
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
fmt.Printf("%q ", token.Text)
}
fmt.Println("")
}
}
// Process the data // Process the data
if err := context.Process(data); err != nil { if err := context.Process(data, cb); err != nil {
return err return err
} }

@ -11,7 +11,7 @@ import (
// CGO // CGO
/* /*
#include <stdbool.h> #include <whisper.h>
*/ */
import "C" import "C"
@ -50,6 +50,35 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v) p.speed_up = toBool(v)
} }
// SetLanguage selects the language by numeric id. The id is translated with
// whisper_lang_str and the resulting C string is stored in the parameters.
// ErrInvalidLanguage is returned when the lookup yields no string.
func (p *Params) SetLanguage(lang int) error {
	name := C.whisper_lang_str(C.int(lang))
	if name == nil {
		return ErrInvalidLanguage
	}
	p.language = name
	return nil
}
// Language returns the id that whisper_lang_id reports for the stored
// language string, or -1 when no language has been set.
func (p *Params) Language() int {
	if p.language != nil {
		return int(C.whisper_lang_id(p.language))
	}
	return -1
}
// SetThreads stores the given thread count in the n_threads parameter.
func (p *Params) SetThreads(threads int) {
	n := C.int(threads)
	p.n_threads = n
}
// SetOffset stores the given offset, in milliseconds, in the offset_ms
// parameter.
func (p *Params) SetOffset(offset_ms int) {
	offset := C.int(offset_ms)
	p.offset_ms = offset
}
// SetDuration stores the given duration, in milliseconds, in the duration_ms
// parameter.
func (p *Params) SetDuration(duration_ms int) {
	duration := C.int(duration_ms)
	p.duration_ms = duration
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS

@ -11,9 +11,10 @@ import (
// ERRORS // ERRORS
var ( var (
ErrUnableToLoadModel = errors.New("unable to load model") ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error") ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed") ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
) )
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

@ -1,11 +1,11 @@
package whisper package whisper
import ( import (
// Bindings
"io" "io"
"strings" "strings"
"time" "time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go" whisper "github.com/ggerganov/whisper.cpp/bindings/go"
) )
@ -34,23 +34,66 @@ func NewContext(model *model, params whisper.Params) (Context, error) {
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// STRINGIFY // PUBLIC METHODS
func (context *context) String() string { // Set the language to use for speech recognition.
str := "<whisper.context" func (context *context) SetLanguage(lang string) error {
return str + ">" if context.model.ctx == nil {
return ErrInternalAppError
}
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
return err
}
// Return success
return nil
} }
/////////////////////////////////////////////////////////////////////////////// // Get language
// PUBLIC METHODS func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
// SetSpeedup forwards the speedup flag to the context's processing
// parameters.
func (context *context) SetSpeedup(v bool) {
	params := &context.params
	params.SetSpeedup(v)
}
// Process new sample data and return any errors // Process new sample data and return any errors
func (context *context) Process(data []float32) error { func (context *context) Process(data []float32, cb SegmentCallback) error {
// Process data
if context.model.ctx == nil { if context.model.ctx == nil {
return ErrInternalAppError return ErrInternalAppError
} else if ret := context.model.ctx.Whisper_full(context.params, data, nil, nil); ret != 0 { }
return ErrProcessingFailed // If the callback is defined then we force on single_segment mode
if cb != nil {
context.params.SetSingleSegment(true)
}
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
} }
// Return success // Return success
@ -59,18 +102,15 @@ func (context *context) Process(data []float32) error {
// Return the next segment of tokens // Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) { func (context *context) NextSegment() (Segment, error) {
result := Segment{}
if context.model.ctx == nil { if context.model.ctx == nil {
return result, ErrInternalAppError return Segment{}, ErrInternalAppError
} }
if context.n >= context.model.ctx.Whisper_full_n_segments() { if context.n >= context.model.ctx.Whisper_full_n_segments() {
return result, io.EOF return Segment{}, io.EOF
} }
// Populate result // Populate result
result.Text = strings.TrimSpace(context.model.ctx.Whisper_full_get_segment_text(context.n)) result := toSegment(context.model.ctx, context.n)
result.Start = time.Duration(context.model.ctx.Whisper_full_get_segment_t0(context.n)) * time.Millisecond
result.End = time.Duration(context.model.ctx.Whisper_full_get_segment_t1(context.n)) * time.Millisecond
// Increment the cursor // Increment the cursor
context.n++ context.n++
@ -78,3 +118,28 @@ func (context *context) NextSegment() (Segment, error) {
// Return success // Return success
return result, nil return result, nil
} }
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
// toSegment assembles the Segment for segment index n of ctx: its number,
// trimmed text, start and end times, and tokens.
func toSegment(ctx *whisper.Context, n int) Segment {
	// The raw t0/t1 values are scaled by 10ms per unit.
	const tick = 10 * time.Millisecond

	text := strings.TrimSpace(ctx.Whisper_full_get_segment_text(n))
	start := time.Duration(ctx.Whisper_full_get_segment_t0(n)) * tick
	end := time.Duration(ctx.Whisper_full_get_segment_t1(n)) * tick

	var segment Segment
	segment.Num = n
	segment.Text = text
	segment.Start = start
	segment.End = end
	segment.Tokens = toTokens(ctx, n)
	return segment
}
// toTokens collects every token of segment n from ctx, recording each token's
// id, whitespace-trimmed text and probability.
func toTokens(ctx *whisper.Context, n int) []Token {
	count := ctx.Whisper_full_n_tokens(n)
	tokens := make([]Token, count)
	for i := range tokens {
		var token Token
		token.Id = int(ctx.Whisper_full_get_token_id(n, i))
		token.Text = strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i))
		token.P = ctx.Whisper_full_get_token_p(n, i)
		tokens[i] = token
	}
	return tokens
}

@ -0,0 +1,55 @@
package whisper_test
import (
"os"
"testing"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "../../models/ggml-tiny.bin"
SamplePath = "../../samples/jfk.wav"
)
// Test_Whisper_000 loads the model, logs the languages it reports and checks
// that it closes without error. The test is skipped when the model or the
// sample file is missing.
func Test_Whisper_000(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}

	// Load model; there is nothing to query or close if loading failed.
	model, err := whisper.New(ModelPath)
	if !assert.NoError(err) || !assert.NotNil(model) {
		return
	}

	// Query the model while it is still open: it must not be used after Close.
	t.Log("languages=", model.Languages())
	assert.NoError(model.Close())
}
// Test_Whisper_001 loads the model and checks that a decoding context can be
// created from it. The test is skipped when the model or the sample file is
// missing.
func Test_Whisper_001(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}

	// Load model. A failed assertion does not stop the test by itself, so
	// return explicitly rather than calling methods on a nil model.
	model, err := whisper.New(ModelPath)
	if !assert.NoError(err) || !assert.NotNil(model) {
		return
	}
	defer func() {
		// Report a failing Close instead of silently dropping its error.
		assert.NoError(model.Close())
	}()

	// Get context for decoding
	ctx, err := model.NewContext()
	assert.NoError(err)
	assert.NotNil(ctx)
}

@ -8,28 +8,56 @@ import (
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// TYPES // TYPES
// SegmentCallback is the callback function for processing segments in real
// time. It is called during the Process function
type SegmentCallback func(Segment)
// Model is the interface to a whisper model. Create a new model with the // Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string) // function whisper.New(string)
type Model interface { type Model interface {
io.Closer io.Closer
// Return a new speech-to-text context.
NewContext() (Context, error) NewContext() (Context, error)
// Return all languages supported.
Languages() []string
} }
// Context is the speech recognition context. // Context is the speech recognition context.
type Context interface { type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition.
Language() string // Get language
SetSpeedup(bool) // Set speedup flag
// Process mono audio data and return any errors. // Process mono audio data and return any errors.
Process([]float32) error // If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback) error
// Return segments until the end of the stream is reached, when io.EOF is // After process is called, return segments until the end of the stream
// returned. // is reached, when io.EOF is returned.
NextSegment() (Segment, error) NextSegment() (Segment, error)
} }
// Segment is the text result of a speech recognition.
type Segment struct { type Segment struct {
// Segment Number
Num int
// Time beginning and end timestamps for the segment. // Time beginning and end timestamps for the segment.
Start, End time.Duration Start, End time.Duration
// The text of the segment. // The text of the segment.
Text string Text string
// The tokens of the segment.
Tokens []Token
}
// Token is a text or special token
type Token struct {
Id int
Text string
P float32
} }

@ -3,6 +3,7 @@ package whisper
import ( import (
"fmt" "fmt"
"os" "os"
"runtime"
// Bindings // Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go" whisper "github.com/ggerganov/whisper.cpp/bindings/go"
@ -63,6 +64,18 @@ func (model *model) String() string {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS // PUBLIC METHODS
// Languages returns the short codes of all languages for which
// Whisper_lang_id reports a valid (non-negative) id.
func (model *model) Languages() []string {
	// Whisper_lang_max_id is the largest valid id, not the number of
	// languages, so the range of ids is inclusive: [0, maxID].
	maxID := whisper.Whisper_lang_max_id()
	result := make([]string, 0, maxID+1)
	for id := 0; id <= maxID; id++ {
		str := whisper.Whisper_lang_str(id)
		if model.ctx.Whisper_lang_id(str) >= 0 {
			result = append(result, str)
		}
	}
	return result
}
func (model *model) NewContext() (Context, error) { func (model *model) NewContext() (Context, error) {
if model.ctx == nil { if model.ctx == nil {
return nil, ErrInternalAppError return nil, ErrInternalAppError
@ -75,7 +88,7 @@ func (model *model) NewContext() (Context, error) {
params.SetPrintProgress(false) params.SetPrintProgress(false)
params.SetPrintRealtime(false) params.SetPrintRealtime(false)
params.SetPrintTimestamps(false) params.SetPrintTimestamps(false)
params.SetSpeedup(false) params.SetThreads(runtime.NumCPU())
// Return new context // Return new context
return NewContext(model, params) return NewContext(model, params)

@ -1,6 +1,9 @@
package whisper package whisper
import "unsafe" import (
"errors"
"unsafe"
)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// CGO // CGO
@ -74,6 +77,13 @@ const (
ChunkSize = C.WHISPER_CHUNK_SIZE ChunkSize = C.WHISPER_CHUNK_SIZE
) )
// Errors returned by the low-level bindings in this file.
var (
// ErrTokenizerFailed is returned by Whisper_tokenize when whisper_tokenize
// reports a negative token count.
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
// ErrAutoDetectFailed is returned by Whisper_lang_auto_detect when
// whisper_lang_auto_detect reports a negative result.
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
// ErrConversionFailed is returned by the wrappers in this file (e.g.
// Whisper_pcm_to_mel, Whisper_encode, Whisper_full) when the underlying C
// call returns a non-zero status.
ErrConversionFailed = errors.New("whisper_convert failed")
// ErrInvalidLanguage is returned by Params.SetLanguage when the language id
// has no corresponding language string.
ErrInvalidLanguage = errors.New("invalid language")
)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS // PUBLIC METHODS
@ -96,34 +106,46 @@ func (ctx *Context) Whisper_free() {
// Convert RAW PCM audio to log mel spectrogram. // Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context. // The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) int { if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
return int(C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads))) return nil
} else {
return ErrConversionFailed
}
} }
// This can be used to set a custom log mel spectrogram inside the provided whisper context. // This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
// Returns 0 on success func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) int { if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return int(C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel))) return nil
} else {
return ErrConversionFailed
}
} }
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram. // offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success func (ctx *Context) Whisper_encode(offset, threads int) error {
func (ctx *Context) Whisper_encode(offset, threads int) int { if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
return int(C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads))) return nil
} else {
return ErrConversionFailed
}
} }
// Run the Whisper decoder to obtain the logits and probabilities for the next token. // Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first. // Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder. // tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls. // n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) int { if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
return int(C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads))) return nil
} else {
return ErrConversionFailed
}
} }
// whisper_sample_best() returns the token with the highest probability // whisper_sample_best() returns the token with the highest probability
@ -136,11 +158,47 @@ func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial))) return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
} }
// Whisper_tokenize converts the provided text into tokens, writing them into
// the caller-supplied tokens slice, which must be large enough to hold the
// result. It returns the number of tokens on success, or ErrTokenizerFailed
// when the tokens slice is empty or whisper_tokenize reports a negative count.
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
	// Taking &tokens[0] on an empty slice would panic; an empty buffer cannot
	// receive any tokens, so report it as a tokenizer failure instead.
	if len(tokens) == 0 {
		return 0, ErrTokenizerFailed
	}

	// The C copy of the text is owned by this call and released on return.
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens)))
	if n < 0 {
		return 0, ErrTokenizerFailed
	}
	return int(n), nil
}
// Return the id of the specified language, returns -1 if not found // Return the id of the specified language, returns -1 if not found
func (ctx *Context) Whisper_lang_id(lang string) int { func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang))) return int(C.whisper_lang_id(C.CString(lang)))
} }
// Whisper_lang_max_id returns the largest language id, i.e. the number of
// available languages minus one.
func Whisper_lang_max_id() int {
	maxID := C.whisper_lang_max_id()
	return int(maxID)
}
// Whisper_lang_str returns the short string for the given language id
// (e.g. 2 -> "de"), or the empty string when the id is not found.
func Whisper_lang_str(id int) string {
	name := C.whisper_lang_str(C.int(id))
	return C.GoString(name)
}
// Whisper_lang_auto_detect uses the mel data at offset_ms to try to detect
// the spoken language. Call whisper_pcm_to_mel() or whisper_set_mel() first.
// It returns one probability per language id, or ErrAutoDetectFailed when the
// underlying call reports a negative result.
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
	// One slot per language id, ids running from 0 to the maximum id inclusive.
	probs := make([]float32, Whisper_lang_max_id()+1)
	ret := C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))
	if int(ret) < 0 {
		return nil, ErrAutoDetectFailed
	}
	return probs, nil
}
func (ctx *Context) Whisper_n_len() int { func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
} }
@ -197,6 +255,11 @@ func (ctx *Context) Whisper_token_beg() Token {
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx))) return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
} }
// Whisper_token_lang returns the token that whisper_token_lang reports for
// the given language id.
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
	token := C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id))
	return Token(token)
}
// Task tokens // Task tokens
func Whisper_token_translate() Token { func Whisper_token_translate() Token {
return Token(C.whisper_token_translate()) return Token(C.whisper_token_translate())
@ -230,25 +293,32 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text. // Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) int { func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback) registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback) registerNewSegmentCallback(ctx, newSegmentCallback)
result := int(C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)))) defer registerEncoderBeginCallback(ctx, nil)
registerEncoderBeginCallback(ctx, nil) defer registerNewSegmentCallback(ctx, nil)
registerNewSegmentCallback(ctx, nil) if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return result return nil
} else {
return ErrConversionFailed
}
} }
// Split the input audio in chunks and process each chunk separately using whisper_full() // Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) int { func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback) registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback) registerNewSegmentCallback(ctx, newSegmentCallback)
result := int(C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors))) defer registerEncoderBeginCallback(ctx, nil)
registerEncoderBeginCallback(ctx, nil) defer registerNewSegmentCallback(ctx, nil)
registerNewSegmentCallback(ctx, nil)
return result if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
return nil
} else {
return ErrConversionFailed
}
} }
// Number of generated text segments. // Number of generated text segments.

@ -2,12 +2,13 @@ package whisper_test
import ( import (
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
// Packages // Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go" whisper "github.com/ggerganov/whisper.cpp/bindings/go"
"github.com/go-audio/wav" wav "github.com/go-audio/wav"
assert "github.com/stretchr/testify/assert" assert "github.com/stretchr/testify/assert"
) )
@ -44,14 +45,12 @@ func Test_Whisper_001(t *testing.T) {
d := wav.NewDecoder(fh) d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer() buf, err := d.FullPCMBuffer()
assert.NoError(err) assert.NoError(err)
buf2 := buf.AsFloat32Buffer()
// Run whisper // Run whisper
ctx := whisper.Whisper_init(ModelPath) ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx) assert.NotNil(ctx)
defer ctx.Whisper_free() defer ctx.Whisper_free()
ret := ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf2.Data, nil, nil) assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
assert.Equal(0, ret)
// Print out tokens // Print out tokens
num_segments := ctx.Whisper_full_n_segments() num_segments := ctx.Whisper_full_n_segments()
@ -64,3 +63,48 @@ func Test_Whisper_001(t *testing.T) {
t.Logf("[%6s->%-6s] %q", t0, t1, str) t.Logf("[%6s->%-6s] %q", t0, t1, str)
} }
} }
// Test_Whisper_002 checks that every language id has a non-empty short
// string.
func Test_Whisper_002(t *testing.T) {
	assert := assert.New(t)
	// Whisper_lang_max_id is the largest valid id, so the bound is inclusive;
	// a strict "<" would leave the last language unchecked.
	for i := 0; i <= whisper.Whisper_lang_max_id(); i++ {
		str := whisper.Whisper_lang_str(i)
		assert.NotEmpty(str)
		t.Log(str)
	}
}
// Test_Whisper_003 converts the sample audio to a mel spectrogram and logs
// the auto-detected probability of every language. The test is skipped when
// the model or the sample file is missing. A failed assertion does not stop
// the test by itself, so each step returns explicitly on failure rather than
// dereferencing a nil buffer or context.
func Test_Whisper_003(t *testing.T) {
	threads := runtime.NumCPU()
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}

	// Open samples
	fh, err := os.Open(SamplePath)
	if !assert.NoError(err) {
		return
	}
	defer fh.Close()

	// Read samples
	d := wav.NewDecoder(fh)
	buf, err := d.FullPCMBuffer()
	if !assert.NoError(err) {
		return
	}
	// Whisper_pcm_to_mel takes the address of the first sample, so an empty
	// buffer must be rejected here.
	samples := buf.AsFloat32Buffer().Data
	if !assert.NotEmpty(samples) {
		return
	}

	// Make the model
	ctx := whisper.Whisper_init(ModelPath)
	if !assert.NotNil(ctx) {
		return
	}
	defer ctx.Whisper_free()

	// Get MEL
	if !assert.NoError(ctx.Whisper_pcm_to_mel(samples, threads)) {
		return
	}

	// Get Languages
	languages, err := ctx.Whisper_lang_auto_detect(0, threads)
	assert.NoError(err)
	for i, p := range languages {
		t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
	}
}

Loading…
Cancel
Save