From 41b223941d2f3da2efd8670b1b58bf8469250a03 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Sun, 18 Dec 2022 19:57:18 +0100 Subject: [PATCH 1/2] Updated bindings --- bindings/go/README.md | 4 +- bindings/go/examples/go-whisper/flags.go | 10 ++ bindings/go/examples/go-whisper/main.go | 2 +- bindings/go/examples/go-whisper/process.go | 24 ++++- bindings/go/params.go | 31 +++++- bindings/go/pkg/whisper/consts.go | 7 +- bindings/go/pkg/whisper/context.go | 99 +++++++++++++++--- bindings/go/pkg/whisper/context_test.go | 55 ++++++++++ bindings/go/pkg/whisper/interface.go | 34 +++++- bindings/go/pkg/whisper/model.go | 15 ++- bindings/go/whisper.go | 116 +++++++++++++++++---- bindings/go/whisper_test.go | 52 ++++++++- 12 files changed, 393 insertions(+), 56 deletions(-) create mode 100755 bindings/go/pkg/whisper/context_test.go diff --git a/bindings/go/README.md b/bindings/go/README.md index 217235a..91a0ab2 100755 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -55,9 +55,11 @@ cd whisper.cpp/bindings/go make test ``` -Then build the examples with: +This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples: ```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp/bindings/go make examples ``` diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go index 61c8af9..a5353d1 100755 --- a/bindings/go/examples/go-whisper/flags.go +++ b/bindings/go/examples/go-whisper/flags.go @@ -42,10 +42,20 @@ func (flags *Flags) GetLanguage() string { return flags.Lookup("language").Value.String() } +func (flags *Flags) IsSpeedup() bool { + return flags.Lookup("speedup").Value.String() == "true" +} + +func (flags *Flags) IsTokens() bool { + return flags.Lookup("tokens").Value.String() == "true" +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS func registerFlags(flag *Flags) { flag.String("model", "", "Path to the model file") flag.String("language", "", "Language") + flag.Bool("speedup", false, "Enable speedup") + flag.Bool("tokens", false, "Display tokens") } diff --git a/bindings/go/examples/go-whisper/main.go b/bindings/go/examples/go-whisper/main.go index 2ea83c2..b3a89db 100755 --- a/bindings/go/examples/go-whisper/main.go +++ b/bindings/go/examples/go-whisper/main.go @@ -36,7 +36,7 @@ func main() { // Process files for _, filename := range flags.Args() { fmt.Println("Processing", filename) - if err := Process(model, filename); err != nil { + if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil { fmt.Fprintln(os.Stderr, err) continue } diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go index 840bc9f..a0e2be8 100755 --- a/bindings/go/examples/go-whisper/process.go +++ b/bindings/go/examples/go-whisper/process.go @@ -11,7 +11,7 @@ import ( wav "github.com/go-audio/wav" ) -func Process(model whisper.Model, path string) error { +func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error { var data []float32 // Create processing context @@ -39,8 +39,28 @@ func Process(model whisper.Model, path string) error { data = buf.AsFloat32Buffer().Data } + // Set the parameters + var cb whisper.SegmentCallback + if lang != "" { + if err := context.SetLanguage(lang); err != nil { + return err + } + } + if speedup { + context.SetSpeedup(true) + } + if tokens { + cb = func(segment whisper.Segment) { + fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) + for _, token := range segment.Tokens { + fmt.Printf("%q ", token.Text) + } + fmt.Println("") + } + } + // Process the data - if err := context.Process(data); err != nil { + if err := context.Process(data, cb); err != nil { return err } diff --git a/bindings/go/params.go b/bindings/go/params.go index 3537144..7f4c509 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -11,7 +11,7 @@ import ( // CGO /* -#include +#include */ import "C" @@ -50,6 +50,35 @@ func (p *Params) SetSpeedup(v bool) { p.speed_up = toBool(v) } +func (p *Params) SetLanguage(lang int) error { + str := C.whisper_lang_str(C.int(lang)) + if str == nil { + return ErrInvalidLanguage + } else { + p.language = str + } + return nil +} + +func (p *Params) Language() int { + if p.language == nil { + return -1 + } + return int(C.whisper_lang_id(p.language)) +} + +func (p *Params) SetThreads(threads int) { + p.n_threads = C.int(threads) +} + +func (p *Params) SetOffset(offset_ms int) { + p.offset_ms = C.int(offset_ms) +} + +func (p *Params) SetDuration(duration_ms int) { + p.duration_ms = C.int(duration_ms) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/bindings/go/pkg/whisper/consts.go b/bindings/go/pkg/whisper/consts.go index d3e3fd8..710073f 100755 --- a/bindings/go/pkg/whisper/consts.go +++ b/bindings/go/pkg/whisper/consts.go @@ -11,9 +11,10 @@ import ( // ERRORS var ( - ErrUnableToLoadModel = errors.New("unable to load model") - ErrInternalAppError = errors.New("internal application error") - ErrProcessingFailed = errors.New("processing failed") + ErrUnableToLoadModel = errors.New("unable to load model") + ErrInternalAppError = errors.New("internal application error") + ErrProcessingFailed = errors.New("processing failed") + ErrUnsupportedLanguage = errors.New("unsupported language") ) /////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index fb48d78..baff611 100755 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -1,11 +1,11 @@ package whisper import ( - // Bindings "io" "strings" "time" + // Bindings whisper "github.com/ggerganov/whisper.cpp/bindings/go" ) @@ -34,23 +34,66 @@ func NewContext(model *model, params whisper.Params) (Context, error) { } /////////////////////////////////////////////////////////////////////////////// -// STRINGIFY +// PUBLIC METHODS -func (context *context) String() string { - str := "" +// Set the language to use for speech recognition. +func (context *context) SetLanguage(lang string) error { + if context.model.ctx == nil { + return ErrInternalAppError + } + if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { + return ErrUnsupportedLanguage + } else if err := context.params.SetLanguage(id); err != nil { + return err + } + // Return success + return nil } -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS +// Get language +func (context *context) Language() string { + return whisper.Whisper_lang_str(context.params.Language()) +} + +// Set speedup flag +func (context *context) SetSpeedup(v bool) { + context.params.SetSpeedup(v) +} // Process new sample data and return any errors -func (context *context) Process(data []float32) error { - // Process data +func (context *context) Process(data []float32, cb SegmentCallback) error { if context.model.ctx == nil { return ErrInternalAppError - } else if ret := context.model.ctx.Whisper_full(context.params, data, nil, nil); ret != 0 { - return ErrProcessingFailed + } + // If the callback is defined then we force on single_segment mode + if cb != nil { + context.params.SetSingleSegment(true) + } + + // We don't do parallel processing at the moment + processors := 0 + if processors > 1 { + if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if cb != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + cb(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err + } + } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { + if cb != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + cb(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err } // Return success @@ -59,18 +102,15 @@ func (context *context) Process(data []float32) error { // Return the next segment of tokens func (context *context) NextSegment() (Segment, error) { - result := Segment{} if context.model.ctx == nil { - return result, ErrInternalAppError + return Segment{}, ErrInternalAppError } if context.n >= context.model.ctx.Whisper_full_n_segments() { - return result, io.EOF + return Segment{}, io.EOF } // Populate result - result.Text = strings.TrimSpace(context.model.ctx.Whisper_full_get_segment_text(context.n)) - result.Start = time.Duration(context.model.ctx.Whisper_full_get_segment_t0(context.n)) * time.Millisecond - result.End = time.Duration(context.model.ctx.Whisper_full_get_segment_t1(context.n)) * time.Millisecond + result := toSegment(context.model.ctx, context.n) // Increment the cursor context.n++ @@ -78,3 +118,28 @@ func (context *context) NextSegment() (Segment, error) { // Return success return result, nil } + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func toSegment(ctx *whisper.Context, n int) Segment { + return Segment{ + Num: n, + Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), + Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, + End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, + Tokens: toTokens(ctx, n), + } +} + +func toTokens(ctx *whisper.Context, n int) []Token { + result := make([]Token, ctx.Whisper_full_n_tokens(n)) + for i := 0; i < len(result); i++ { + result[i] = Token{ + Id: int(ctx.Whisper_full_get_token_id(n, i)), + Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)), + P: ctx.Whisper_full_get_token_p(n, i), + } + } + return result +} diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go new file mode 100755 index 0000000..c8c6016 --- /dev/null +++ b/bindings/go/pkg/whisper/context_test.go @@ -0,0 +1,55 @@ +package whisper_test + +import ( + "os" + "testing" + + // Packages + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + assert "github.com/stretchr/testify/assert" +) + +const ( + ModelPath = "../../models/ggml-tiny.bin" + SamplePath = "../../samples/jfk.wav" +) + +func Test_Whisper_000(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Load model + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + assert.NoError(model.Close()) + + t.Log("languages=", model.Languages()) +} + +func Test_Whisper_001(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Load model + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + // Get context for decoding + ctx, err := model.NewContext() + assert.NoError(err) + assert.NotNil(ctx) + +} diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index d9d91cf..53e4f3f 100755 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -8,28 +8,56 @@ import ( /////////////////////////////////////////////////////////////////////////////// // TYPES +// SegmentCallback is the callback function for processing segments in real +// time. It is called during the Process function +type SegmentCallback func(Segment) + // Model is the interface to a whisper model. Create a new model with the // function whisper.New(string) type Model interface { io.Closer + // Return a new speech-to-text context. NewContext() (Context, error) + + // Return all languages supported. + Languages() []string } // Context is the speach recognition context. type Context interface { + SetLanguage(string) error // Set the language to use for speech recognition. + Language() string // Get language + SetSpeedup(bool) // Set speedup flag + // Process mono audio data and return any errors. - Process([]float32) error + // If defined, newly generated segments are passed to the + // callback function during processing. + Process([]float32, SegmentCallback) error - // Return segments until the end of the stream is reached, when io.EOF is - // returned. + // After process is called, return segments until the end of the stream + // is reached, when io.EOF is returned. NextSegment() (Segment, error) } +// Segment is the text result of a speech recognition. type Segment struct { + // Segment Number + Num int + // Time beginning and end timestamps for the segment. Start, End time.Duration // The text of the segment. Text string + + // The tokens of the segment. + Tokens []Token +} + +// Token is a text or special token +type Token struct { + Id int + Text string + P float32 } diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index ad9a3eb..13cb52c 100755 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -3,6 +3,7 @@ package whisper import ( "fmt" "os" + "runtime" // Bindings whisper "github.com/ggerganov/whisper.cpp/bindings/go" @@ -63,6 +64,18 @@ func (model *model) String() string { /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS +// Return all recognized languages. Initially it is set to auto-detect +func (model *model) Languages() []string { + result := make([]string, 0, whisper.Whisper_lang_max_id()) + for i := 0; i < whisper.Whisper_lang_max_id(); i++ { + str := whisper.Whisper_lang_str(i) + if model.ctx.Whisper_lang_id(str) >= 0 { + result = append(result, str) + } + } + return result +} + func (model *model) NewContext() (Context, error) { if model.ctx == nil { return nil, ErrInternalAppError @@ -75,7 +88,7 @@ func (model *model) NewContext() (Context, error) { params.SetPrintProgress(false) params.SetPrintRealtime(false) params.SetPrintTimestamps(false) - params.SetSpeedup(false) + params.SetThreads(runtime.NumCPU()) // Return new context return NewContext(model, params) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index d9d8ddf..2584f7b 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -1,6 +1,9 @@ package whisper -import "unsafe" +import ( + "errors" + "unsafe" +) /////////////////////////////////////////////////////////////////////////////// // CGO @@ -74,6 +77,13 @@ const ( ChunkSize = C.WHISPER_CHUNK_SIZE ) +var ( + ErrTokenizerFailed = errors.New("whisper_tokenize failed") + ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed") + ErrConversionFailed = errors.New("whisper_convert failed") + ErrInvalidLanguage = errors.New("invalid language") +) + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS @@ -96,34 +106,46 @@ func (ctx *Context) Whisper_free() { // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the provided whisper context. -// Returns 0 on success -func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) int { - return int(C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads))) +func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { + if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } } // This can be used to set a custom log mel spectrogram inside the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 -// Returns 0 on success -func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) int { - return int(C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel))) +func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { + if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { + return nil + } else { + return ErrConversionFailed + } } // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. -// Returns 0 on success -func (ctx *Context) Whisper_encode(offset, threads int) int { - return int(C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads))) +func (ctx *Context) Whisper_encode(offset, threads int) error { + if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } } // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. // tokens + n_tokens is the provided context for the decoder. // n_past is the number of tokens to use from previous decoder calls. -// Returns 0 on success -func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) int { - return int(C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads))) +func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error { + if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } } // whisper_sample_best() returns the token with the highest probability @@ -136,11 +158,47 @@ func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData { return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial))) } +// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens. +// Returns the number of tokens on success +func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) { + cText := C.CString(text) + defer C.free(unsafe.Pointer(cText)) + if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 { + return int(n), nil + } else { + return 0, ErrTokenizerFailed + } +} + // Return the id of the specified language, returns -1 if not found func (ctx *Context) Whisper_lang_id(lang string) int { return int(C.whisper_lang_id(C.CString(lang))) } +// Largest language id (i.e. number of available languages - 1) +func Whisper_lang_max_id() int { + return int(C.whisper_lang_max_id()) +} + +// Return the short string of the specified language id (e.g. 2 -> "de"), +// returns empty string if not found +func Whisper_lang_str(id int) string { + return C.GoString(C.whisper_lang_str(C.int(id))) +} + +// Use mel data at offset_ms to try and auto-detect the spoken language +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// Returns the probabilities of all languages. +// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 +func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) { + probs := make([]float32, Whisper_lang_max_id()+1) + if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { + return nil, ErrAutoDetectFailed + } else { + return probs, nil + } +} + func (ctx *Context) Whisper_n_len() int { return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) } @@ -197,6 +255,11 @@ func (ctx *Context) Whisper_token_beg() Token { return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx))) } +// Special tokens +func (ctx *Context) Whisper_token_lang(lang_id int) Token { + return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id))) +} + // Task tokens func Whisper_token_translate() Token { return Token(C.whisper_token_translate()) @@ -230,25 +293,32 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Uses the specified decoding strategy to obtain the text. -func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) int { +func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { registerEncoderBeginCallback(ctx, encoderBeginCallback) registerNewSegmentCallback(ctx, newSegmentCallback) - result := int(C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)))) - registerEncoderBeginCallback(ctx, nil) - registerNewSegmentCallback(ctx, nil) - return result + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 { + return nil + } else { + return ErrConversionFailed + } } // Split the input audio in chunks and process each chunk separately using whisper_full() // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. -func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) int { +func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { registerEncoderBeginCallback(ctx, encoderBeginCallback) registerNewSegmentCallback(ctx, newSegmentCallback) - result := int(C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors))) - registerEncoderBeginCallback(ctx, nil) - registerNewSegmentCallback(ctx, nil) - return result + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + + if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 { + return nil + } else { + return ErrConversionFailed + } } // Number of generated text segments. diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go index 316438b..d7b8cae 100644 --- a/bindings/go/whisper_test.go +++ b/bindings/go/whisper_test.go @@ -2,12 +2,13 @@ package whisper_test import ( "os" + "runtime" "testing" "time" // Packages whisper "github.com/ggerganov/whisper.cpp/bindings/go" - "github.com/go-audio/wav" + wav "github.com/go-audio/wav" assert "github.com/stretchr/testify/assert" ) @@ -44,14 +45,12 @@ func Test_Whisper_001(t *testing.T) { d := wav.NewDecoder(fh) buf, err := d.FullPCMBuffer() assert.NoError(err) - buf2 := buf.AsFloat32Buffer() // Run whisper ctx := whisper.Whisper_init(ModelPath) assert.NotNil(ctx) defer ctx.Whisper_free() - ret := ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf2.Data, nil, nil) - assert.Equal(0, ret) + assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil)) // Print out tokens num_segments := ctx.Whisper_full_n_segments() @@ -64,3 +63,48 @@ func Test_Whisper_001(t *testing.T) { t.Logf("[%6s->%-6s] %q", t0, t1, str) } } + +func Test_Whisper_002(t *testing.T) { + assert := assert.New(t) + for i := 0; i < whisper.Whisper_lang_max_id(); i++ { + str := whisper.Whisper_lang_str(i) + assert.NotEmpty(str) + t.Log(str) + } +} + +func Test_Whisper_003(t *testing.T) { + threads := runtime.NumCPU() + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Open samples + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Read samples + d := wav.NewDecoder(fh) + buf, err := d.FullPCMBuffer() + assert.NoError(err) + + // Make the model + ctx := whisper.Whisper_init(ModelPath) + assert.NotNil(ctx) + defer ctx.Whisper_free() + + // Get MEL + assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads)) + + // Get Languages + languages, err := ctx.Whisper_lang_auto_detect(0, threads) + assert.NoError(err) + for i, p := range languages { + t.Logf("%s: %f", whisper.Whisper_lang_str(i), p) + } +} From b16e0c3fc5180c373541c1e7a24a816721efd451 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Sun, 18 Dec 2022 20:06:42 +0100 Subject: [PATCH 2/2] Makefile update to add in more tests --- bindings/go/Makefile | 1 + bindings/go/README.md | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bindings/go/Makefile b/bindings/go/Makefile index 01dba95..3374212 100755 --- a/bindings/go/Makefile +++ b/bindings/go/Makefile @@ -14,6 +14,7 @@ whisper: mkdir test: model-small whisper @go mod tidy @go test -v . + @go test -v ./pkg/whisper/... examples: $(EXAMPLES_DIR) diff --git a/bindings/go/README.md b/bindings/go/README.md index 91a0ab2..8ae89c7 100755 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -4,6 +4,7 @@ This package provides Go bindings for whisper.cpp. They have been tested on: * Darwin (OS X) 12.6 on x64_64 * Debian Linux on arm64 + * Fedora Linux on x86_64 The "low level" bindings are in the `bindings/go` directory and there is a more Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage @@ -30,7 +31,7 @@ func main() { if err != nil { panic(err) } - if err := context.Process(samples); err != nil { + if err := context.Process(samples, nil); err != nil { return err } @@ -58,8 +59,6 @@ make test This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples: ```bash -git clone https://github.com/ggerganov/whisper.cpp.git -cd whisper.cpp/bindings/go make examples ```