go : adding features to the go-whisper example, go ci, etc (#384)

* Updated bindings so they can be used in third-party packages.

* Updated makefiles to optionally set the FMA flag, for Xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private
pull/388/head
David Thorpe 2 years ago committed by GitHub
parent f30b5d322c
commit f078a6f20e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,17 @@
# Workflow that runs the Go bindings tests whenever a push changes files
# under bindings/go.
name: Bindings Tests
on:
  push:
    paths:
      - bindings/go/**
jobs:
  ubuntu-latest:
    runs-on: ubuntu-latest
    steps:
      # Install a Go toolchain matching the ^1.19 version constraint.
      - uses: actions/setup-go@v3
        with:
          go-version: '^1.19'
      # Fetch the repository contents.
      - uses: actions/checkout@v1
      # Run the "test" target of the makefile in bindings/go.
      - run: |
          cd bindings/go
          make test

@ -0,0 +1,22 @@
package main
import "fmt"
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
	// Reset is the escape sequence that restores the default text attributes.
	Reset = "\033[0m"
	// RGBPrefix starts an 8-bit (256-colour palette) foreground colour
	// sequence; it is followed by a single palette index in decimal.
	RGBPrefix = "\033[38;5;"
	// RGBSuffix terminates the colour sequence started by RGBPrefix.
	RGBSuffix = "m"
)

///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS

// Colorize wraps text in ANSI escape codes so that it is displayed in one of
// the 24 grayscale palette entries (indices 232-255), then resets the colour.
//
// v selects the gray level, from 0 (palette index 232) to 23 (palette index
// 255). Values outside that range are clamped rather than wrapped, so a
// caller passing int(p*24.0) with p == 1.0 gets the top level instead of
// wrapping round to level 0, and a negative v can no longer select an index
// below the grayscale ramp.
func Colorize(text string, v int) string {
	// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
	// Grayscale colors are in the range 232-255
	if v < 0 {
		v = 0
	} else if v > 23 {
		v = 23
	}
	return RGBPrefix + fmt.Sprint(v+232) + RGBSuffix + text + Reset
}

@ -2,6 +2,12 @@ package main
import ( import (
"flag" "flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
) )
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String() return flags.Lookup("language").Value.String()
} }
// IsTranslate returns the boolean value of the "translate" flag.
func (flags *Flags) IsTranslate() bool {
	getter := flags.Lookup("translate").Value.(flag.Getter)
	return getter.Get().(bool)
}
// GetOffset returns the value of the "offset" flag as a duration.
func (flags *Flags) GetOffset() time.Duration {
	getter := flags.Lookup("offset").Value.(flag.Getter)
	return getter.Get().(time.Duration)
}
// GetDuration returns the value of the "duration" flag as a duration.
func (flags *Flags) GetDuration() time.Duration {
	getter := flags.Lookup("duration").Value.(flag.Getter)
	return getter.Get().(time.Duration)
}
// GetThreads returns the unsigned integer value of the "threads" flag.
func (flags *Flags) GetThreads() uint {
	getter := flags.Lookup("threads").Value.(flag.Getter)
	return getter.Get().(uint)
}
// GetOut returns the value of the "out" flag converted to lower case.
func (flags *Flags) GetOut() string {
	format := flags.Lookup("out").Value.String()
	return strings.ToLower(format)
}
func (flags *Flags) IsSpeedup() bool { func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true" return flags.Lookup("speedup").Value.String() == "true"
} }
@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true" return flags.Lookup("tokens").Value.String() == "true"
} }
// IsColorize returns true when the "colorize" flag's string value is "true".
func (flags *Flags) IsColorize() bool {
	setting := flags.Lookup("colorize").Value.String()
	return setting == "true"
}
// GetMaxLen returns the unsigned integer value of the "max-len" flag.
func (flags *Flags) GetMaxLen() uint {
	getter := flags.Lookup("max-len").Value.(flag.Getter)
	return getter.Get().(uint)
}
// GetMaxTokens returns the unsigned integer value of the "max-tokens" flag.
func (flags *Flags) GetMaxTokens() uint {
	getter := flags.Lookup("max-tokens").Value.(flag.Getter)
	return getter.Get().(uint)
}
// GetWordThreshold returns the value of the "word-thold" flag, narrowed from
// the float64 held by the flag to a float32.
func (flags *Flags) GetWordThreshold() float32 {
	getter := flags.Lookup("word-thold").Value.(flag.Getter)
	threshold := getter.Get().(float64)
	return float32(threshold)
}
// SetParams copies the command-line options onto the given whisper context.
//
// Each option is applied only when it differs from its zero value (and, for
// the language, when it is not "auto"); every applied option is reported on
// the flag set's output. Translation is only enabled when the context reports
// a multilingual model. The only error returned is the one from SetLanguage.
func (flags *Flags) SetParams(context whisper.Context) error {
	// report writes one progress line to the flag set's output.
	report := func(format string, args ...interface{}) {
		fmt.Fprintf(flags.Output(), format, args...)
	}

	lang := flags.GetLanguage()
	if lang != "" && lang != "auto" {
		report("Setting language to %q\n", lang)
		if err := context.SetLanguage(lang); err != nil {
			return err
		}
	}

	if flags.IsTranslate() && context.IsMultilingual() {
		report("Setting translate to true\n")
		context.SetTranslate(true)
	}

	offset := flags.GetOffset()
	if offset != 0 {
		report("Setting offset to %v\n", offset)
		context.SetOffset(offset)
	}

	duration := flags.GetDuration()
	if duration != 0 {
		report("Setting duration to %v\n", duration)
		context.SetDuration(duration)
	}

	if flags.IsSpeedup() {
		report("Setting speedup to true\n")
		context.SetSpeedup(true)
	}

	threads := flags.GetThreads()
	if threads != 0 {
		report("Setting threads to %d\n", threads)
		context.SetThreads(threads)
	}

	maxLen := flags.GetMaxLen()
	if maxLen != 0 {
		report("Setting max_segment_length to %d\n", maxLen)
		context.SetMaxSegmentLength(maxLen)
	}

	maxTokens := flags.GetMaxTokens()
	if maxTokens != 0 {
		report("Setting max_tokens to %d\n", maxTokens)
		context.SetMaxTokensPerSegment(maxTokens)
	}

	wordThreshold := flags.GetWordThreshold()
	if wordThreshold != 0 {
		report("Setting word_threshold to %f\n", wordThreshold)
		context.SetTokenThreshold(wordThreshold)
	}

	// Every option was applied
	return nil
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS
func registerFlags(flag *Flags) { func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file") flag.String("model", "", "Path to the model file")
flag.String("language", "", "Language") flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup") flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens") flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
} }

@ -35,8 +35,7 @@ func main() {
// Process files // Process files
for _, filename := range flags.Args() { for _, filename := range flags.Args() {
fmt.Println("Processing", filename) if err := Process(model, filename, flags); err != nil {
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
continue continue
} }

@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav" wav "github.com/go-audio/wav"
) )
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error { func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32 var data []float32
// Create processing context // Create processing context
@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
return err return err
} }
// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}
// Open the file // Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path) fh, err := os.Open(path)
if err != nil { if err != nil {
return err return err
} }
defer fh.Close() defer fh.Close()
// Decode the WAV file // Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh) dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil { if buf, err := dec.FullPCMBuffer(); err != nil {
return err return err
@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
data = buf.AsFloat32Buffer().Data data = buf.AsFloat32Buffer().Data
} }
// Set the parameters // Segment callback when -tokens is specified
var cb whisper.SegmentCallback var cb whisper.SegmentCallback
if lang != "" { if flags.IsTokens() {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
cb = func(segment whisper.Segment) { cb = func(segment whisper.Segment) {
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens { for _, token := range segment.Tokens {
fmt.Printf("%q ", token.Text) if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
} }
fmt.Println("") fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
} }
} }
// Process the data // Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil { if err := context.Process(data, cb); err != nil {
return err return err
} }
// Print out the results // Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}
// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for { for {
segment, err := context.NextSegment() segment, err := context.NextSegment()
if err == io.EOF { if err == io.EOF {
break return nil
} else if err != nil { } else if err != nil {
return err return err
} }
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text) fmt.Fprintln(w, n)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
} }
}
// Output writes every remaining segment of the context to w, one line per
// segment, each prefixed with its start and end time truncated to
// milliseconds. When colorize is set, the line is built from the segment's
// text tokens, each one colored according to its probability; otherwise the
// segment text is written as-is. It returns nil once NextSegment reports
// io.EOF, or the first other error NextSegment returns.
func Output(w io.Writer, context whisper.Context, colorize bool) error {
	for {
		segment, err := context.NextSegment()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		start := segment.Start.Truncate(time.Millisecond)
		end := segment.End.Truncate(time.Millisecond)
		fmt.Fprintf(w, "[%6s->%6s]", start, end)

		if !colorize {
			fmt.Fprintln(w, " ", segment.Text)
			continue
		}

		// Colored output: emit only the text tokens
		for _, token := range segment.Tokens {
			if context.IsText(token) {
				shade := int(token.P * 24.0)
				fmt.Fprint(w, " ", Colorize(token.Text, shade))
			}
		}
		fmt.Fprint(w, "\n")
	}
}
// Return success // Return srtTimestamp
return nil func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
} }

@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v) p.speed_up = toBool(v)
} }
// Set language id
func (p *Params) SetLanguage(lang int) error { func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang)) str := C.whisper_lang_str(C.int(lang))
if str == nil { if str == nil {
@ -57,6 +58,7 @@ func (p *Params) SetLanguage(lang int) error {
return nil return nil
} }
// Get language id
func (p *Params) Language() int { func (p *Params) Language() int {
if p.language == nil { if p.language == nil {
return -1 return -1
@ -64,18 +66,41 @@ func (p *Params) Language() int {
return int(C.whisper_lang_id(p.language)) return int(C.whisper_lang_id(p.language))
} }
// Set number of threads to use
func (p *Params) SetThreads(threads int) { func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads) p.n_threads = C.int(threads)
} }
// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) { func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms) p.offset_ms = C.int(offset_ms)
} }
// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) { func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms) p.duration_ms = C.int(duration_ms)
} }
// SetTokenThreshold sets the timestamp token probability threshold (~0.01).
// The value is stored in the thold_pt field as a C float.
func (p *Params) SetTokenThreshold(t float32) {
	p.thold_pt = C.float(t)
}
// SetTokenSumThreshold sets the timestamp token sum probability threshold
// (~0.01). The value is stored in the thold_ptsum field as a C float.
func (p *Params) SetTokenSumThreshold(t float32) {
	p.thold_ptsum = C.float(t)
}
// SetMaxSegmentLength sets the maximum segment length in characters.
// The value is stored in the max_len field as a C int.
func (p *Params) SetMaxSegmentLength(n int) {
	p.max_len = C.int(n)
}
// SetMaxTokensPerSegment sets the maximum number of tokens per segment
// (0 = no limit). The value is stored in the max_tokens field as a C int.
func (p *Params) SetMaxTokensPerSegment(n int) {
	p.max_tokens = C.int(n)
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS

@ -11,10 +11,11 @@ import (
// ERRORS // ERRORS
var ( var (
ErrUnableToLoadModel = errors.New("unable to load model") ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error") ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed") ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language") ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
) )
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE // LIFECYCLE
func NewContext(model *model, params whisper.Params) (Context, error) { func newContext(model *model, params whisper.Params) (Context, error) {
context := new(context) context := new(context)
context.model = model context.model = model
context.params = params context.params = params
@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil { if context.model.ctx == nil {
return ErrInternalAppError return ErrInternalAppError
} }
if !context.model.IsMultilingual() {
return ErrModelNotMultilingual
}
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil { } else if err := context.params.SetLanguage(id); err != nil {
@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error {
return nil return nil
} }
// IsMultilingual returns whatever the context's model reports from its own
// IsMultilingual method.
func (context *context) IsMultilingual() bool {
	return context.model.IsMultilingual()
}
// Get language // Get language
func (context *context) Language() string { func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language()) return whisper.Whisper_lang_str(context.params.Language())
} }
// SetTranslate sets the translate flag by forwarding v to the context's
// processing parameters.
func (context *context) SetTranslate(v bool) {
	context.params.SetTranslate(v)
}
// Set speedup flag // Set speedup flag
func (context *context) SetSpeedup(v bool) { func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v) context.params.SetSpeedup(v)
} }
// SetThreads sets the number of threads to use, converting the unsigned
// count to the int expected by the processing parameters.
func (context *context) SetThreads(v uint) {
	threads := int(v)
	context.params.SetThreads(threads)
}
// SetOffset sets the time offset, passing it to the processing parameters
// as a whole number of milliseconds.
func (context *context) SetOffset(v time.Duration) {
	offsetMs := int(v.Milliseconds())
	context.params.SetOffset(offsetMs)
}
// SetDuration sets the duration of audio to process, passing it to the
// processing parameters as a whole number of milliseconds.
//
// This must call params.SetDuration: calling params.SetOffset here (as the
// previous version did) overwrote the start offset with the duration and left
// the duration itself unset.
func (context *context) SetDuration(v time.Duration) {
	context.params.SetDuration(int(v.Milliseconds()))
}
// SetTokenThreshold sets the timestamp token probability threshold (~0.01)
// by forwarding t to the context's processing parameters.
func (context *context) SetTokenThreshold(t float32) {
	context.params.SetTokenThreshold(t)
}
// SetTokenSumThreshold sets the timestamp token sum probability threshold
// (~0.01) by forwarding t to the context's processing parameters.
func (context *context) SetTokenSumThreshold(t float32) {
	context.params.SetTokenSumThreshold(t)
}
// SetMaxSegmentLength sets the maximum segment length in characters,
// converting the unsigned value to the int expected by the processing
// parameters.
func (context *context) SetMaxSegmentLength(n uint) {
	length := int(n)
	context.params.SetMaxSegmentLength(length)
}
// SetMaxTokensPerSegment sets the maximum number of tokens per segment
// (0 = no limit), converting the unsigned value to the int expected by the
// processing parameters.
func (context *context) SetMaxTokensPerSegment(n uint) {
	limit := int(n)
	context.params.SetMaxTokensPerSegment(limit)
}
// Process new sample data and return any errors // Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error { func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil { if context.model.ctx == nil {
@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) {
return result, nil return result, nil
} }
// IsText returns true when t is a text token. A token is treated as
// non-text when it is the "begin" or "start of transcription" token, when
// its id is at or above the "end of transcription" token id, or when it is
// the "start of prev", "start of lm" or "no timestamps" token.
func (context *context) IsText(t Token) bool {
	if context.IsBEG(t) || context.IsSOT(t) {
		return false
	}
	if whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot() {
		return false
	}
	return !(context.IsPREV(t) || context.IsSOLM(t) || context.IsNOT(t))
}
// IsBEG returns true when t's id equals the model's "begin" token id.
func (context *context) IsBEG(t Token) bool {
	id := whisper.Token(t.Id)
	beg := context.model.ctx.Whisper_token_beg()
	return id == beg
}
// IsSOT returns true when t's id equals the model's "start of
// transcription" token id.
func (context *context) IsSOT(t Token) bool {
	id := whisper.Token(t.Id)
	sot := context.model.ctx.Whisper_token_sot()
	return id == sot
}
// IsEOT returns true when t's id equals the model's "end of transcription"
// token id.
func (context *context) IsEOT(t Token) bool {
	id := whisper.Token(t.Id)
	eot := context.model.ctx.Whisper_token_eot()
	return id == eot
}
// IsPREV returns true when t's id equals the model's "start of prev" token
// id.
func (context *context) IsPREV(t Token) bool {
	id := whisper.Token(t.Id)
	prev := context.model.ctx.Whisper_token_prev()
	return id == prev
}
// IsSOLM returns true when t's id equals the model's "start of lm" token id.
func (context *context) IsSOLM(t Token) bool {
	id := whisper.Token(t.Id)
	solm := context.model.ctx.Whisper_token_solm()
	return id == solm
}
// IsNOT returns true when t's id equals the model's "No timestamps" token
// id.
func (context *context) IsNOT(t Token) bool {
	id := whisper.Token(t.Id)
	noTimestamps := context.model.ctx.Whisper_token_not()
	return id == noTimestamps
}
// IsLANG returns true when t is the token associated with the language
// named by lang. It returns false when the model has no id (a negative
// lookup result) for that language.
func (context *context) IsLANG(t Token, lang string) bool {
	id := context.model.ctx.Whisper_lang_id(lang)
	if id < 0 {
		return false
	}
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS

@ -20,6 +20,9 @@ type Model interface {
// Return a new speech-to-text context. // Return a new speech-to-text context.
NewContext() (Context, error) NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported. // Return all languages supported.
Languages() []string Languages() []string
} }
@ -27,8 +30,18 @@ type Model interface {
// Context is the speech recognition context. // Context is the speech recognition context.
type Context interface { type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition. SetLanguage(string) error // Set the language to use for speech recognition.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language Language() string // Get language
SetSpeedup(bool) // Set speedup flag
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors. // Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the // If defined, newly generated segments are passed to the
@ -38,6 +51,15 @@ type Context interface {
// After process is called, return segments until the end of the stream // After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned. // is reached, when io.EOF is returned.
NextSegment() (Segment, error) NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
} }
// Segment is the text result of a speech recognition. // Segment is the text result of a speech recognition.

@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE // LIFECYCLE
func New(path string) (*model, error) { func New(path string) (Model, error) {
model := new(model) model := new(model)
if _, err := os.Stat(path); err != nil { if _, err := os.Stat(path); err != nil {
return nil, err return nil, err
@ -64,6 +64,11 @@ func (model *model) String() string {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS // PUBLIC METHODS
// IsMultilingual returns true if the model is multilingual (language and
// translation options are supported), i.e. when the underlying whisper
// context reports a non-zero multilingual value.
func (model *model) IsMultilingual() bool {
	multilingual := model.ctx.Whisper_is_multilingual()
	return multilingual != 0
}
// Return all recognized languages. Initially it is set to auto-detect // Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string { func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id()) result := make([]string, 0, whisper.Whisper_lang_max_id())
@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) {
params.SetThreads(runtime.NumCPU()) params.SetThreads(runtime.NumCPU())
// Return new context // Return new context
return NewContext(model, params) return newContext(model, params)
} }

Loading…
Cancel
Save