diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 0a6e9cb..a1a0222 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -118,12 +118,12 @@ func (context *context) SetMaxTokensPerSegment(n uint) { // ResetTimings resets the mode timings. Should be called before processing func (context *context) ResetTimings() { - context.model.ctx.Whisper_reset_timings() + context.model.state.Whisper_reset_timings() } // PrintTimings prints the model timings to stdout. func (context *context) PrintTimings() { - context.model.ctx.Whisper_print_timings() + context.model.ctx.Whisper_print_timings(context.model.state) } // SystemInfo returns the system information @@ -139,7 +139,7 @@ func (context *context) SystemInfo() string { // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Returns the probabilities of all languages. func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) { - langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads) + langProbs, err := context.model.ctx.Whisper_lang_auto_detect(context.model.state, offset_ms, n_threads) if err != nil { return nil, err } @@ -159,23 +159,23 @@ func (context *context) Process(data []float32, cb SegmentCallback) error { // We don't do parallel processing at the moment processors := 0 if processors > 1 { - if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if err := context.model.ctx.Whisper_full_parallel(context.model.state, context.params, data, processors, nil, func(new int) { if cb != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() + num_segments := context.model.state.Whisper_full_n_segments() s0 := num_segments - new for i := s0; i < num_segments; i++ { - cb(toSegment(context.model.ctx, i)) + cb(toSegment(context.model, i)) } } }); err != nil { return err } - } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { + } else if err := context.model.ctx.Whisper_full_with_state( context.model.state, context.params, data, nil, func(new int) { if cb != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() + num_segments := context.model.state.Whisper_full_n_segments() s0 := num_segments - new for i := s0; i < num_segments; i++ { - cb(toSegment(context.model.ctx, i)) + cb(toSegment(context.model, i)) } } }); err != nil { @@ -188,15 +188,15 @@ func (context *context) Process(data []float32, cb SegmentCallback) error { // Return the next segment of tokens func (context *context) NextSegment() (Segment, error) { - if context.model.ctx == nil { + if context.model.state == nil { return Segment{}, ErrInternalAppError } - if context.n >= context.model.ctx.Whisper_full_n_segments() { + if context.n >= context.model.state.Whisper_full_n_segments() { return Segment{}, io.EOF } // Populate result - result := toSegment(context.model.ctx, context.n) + result := toSegment(context.model, context.n) // Increment the cursor context.n++ @@ -267,23 +267,23 @@ func (context *context) IsLANG(t Token, lang string) bool { /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS -func toSegment(ctx *whisper.Context, n int) Segment { +func toSegment(model *model, n int) Segment { return Segment{ Num: n, - Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), - Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, - End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, - Tokens: toTokens(ctx, n), + Text: strings.TrimSpace(model.state.Whisper_full_get_segment_text(n)), + Start: time.Duration(model.state.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, + End: time.Duration(model.state.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, + Tokens: toTokens(model, n), } } -func toTokens(ctx *whisper.Context, n int) []Token { - result := make([]Token, ctx.Whisper_full_n_tokens(n)) +func toTokens(model *model, n int) []Token { + result := make([]Token, model.state.Whisper_full_n_tokens(n)) for i := 0; i < len(result); i++ { result[i] = Token{ - Id: int(ctx.Whisper_full_get_token_id(n, i)), - Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)), - P: ctx.Whisper_full_get_token_p(n, i), + Id: int(model.state.Whisper_full_get_token_id(n, i)), + Text: strings.TrimSpace(model.ctx.Whisper_full_get_token_text(model.state, n, i)), + P: model.state.Whisper_full_get_token_p(n, i), } } return result diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 94c2197..6f80dc9 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -15,6 +15,7 @@ import ( type model struct { path string ctx *whisper.Context + state *whisper.State } // Make sure model adheres to the interface @@ -30,6 +31,11 @@ func New(path string) (Model, error) { } else if ctx := whisper.Whisper_init(path); ctx == nil { return nil, ErrUnableToLoadModel } else { + state := ctx.Whisper_init_state(); + if state == nil { + return nil, ErrUnableToLoadModel + } + model.state = state model.ctx = ctx model.path = path } @@ -43,8 +49,13 @@ func (model *model) Close() error { model.ctx.Whisper_free() } + if model.state != nil { + model.state.Whisper_free_state() + } + // Release resources model.ctx = nil + model.state = nil // Return success return nil diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 78ca07d..e0dec0d 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -20,8 +20,8 @@ extern bool callEncoderBegin(void* user_data); // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments -static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) { - if(user_data != NULL && ctx != NULL) { +static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) { + if(user_data != NULL && ctx != NULL && state != NULL) { callNewSegment(user_data, n_new); } } @@ -29,8 +29,8 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted -static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) { - if(user_data != NULL && ctx != NULL) { +static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) { + if(user_data != NULL && ctx != NULL && state != NULL) { return callEncoderBegin(user_data); } return false; @@ -53,6 +53,7 @@ import "C" type ( Context C.struct_whisper_context + State C.struct_whisper_state Token C.whisper_token TokenData C.struct_whisper_token_data SamplingStrategy C.enum_whisper_sampling_strategy @@ -98,15 +99,28 @@ func Whisper_init(path string) *Context { } } +func (ctx *Context) Whisper_init_state() *State { + state := C.whisper_init_state((*C.struct_whisper_context)(ctx)) + if state != nil { + return (*State)(state) + } else { + return nil + } +} + // Frees all memory allocated by the model. func (ctx *Context) Whisper_free() { C.whisper_free((*C.struct_whisper_context)(ctx)) } +func (state *State) Whisper_free_state() { + C.whisper_free_state((*C.struct_whisper_state)(state)) +} + // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the provided whisper context. -func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { - if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { +func (ctx *Context) Whisper_pcm_to_mel(state *State, data []float32, threads int) error { + if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx),(*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { return nil } else { return ErrConversionFailed @@ -116,8 +130,8 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { // This can be used to set a custom log mel spectrogram inside the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 -func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { - if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { +func (state *State) Whisper_set_mel(data []float32, n_mel int) error { + if C.whisper_set_mel((*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { return nil } else { return ErrConversionFailed @@ -127,8 +141,8 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. -func (ctx *Context) Whisper_encode(offset, threads int) error { - if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 { +func (ctx *Context) Whisper_encode(state *State, offset, threads int) error { + if C.whisper_encode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 { return nil } else { return ErrConversionFailed @@ -139,8 +153,8 @@ func (ctx *Context) Whisper_encode(offset, threads int) error { // Make sure to call whisper_encode() first. // tokens + n_tokens is the provided context for the decoder. // n_past is the number of tokens to use from previous decoder calls. -func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error { - if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { +func (ctx *Context) Whisper_decode(state *State, tokens []Token, past, threads int) error { + if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { return nil } else { return ErrConversionFailed @@ -183,17 +197,17 @@ func Whisper_lang_str(id int) string { // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Returns the probabilities of all languages. // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 -func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) { +func (ctx *Context) Whisper_lang_auto_detect(state *State, offset_ms, n_threads int) ([]float32, error) { probs := make([]float32, Whisper_lang_max_id()+1) - if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { + if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { return nil, ErrAutoDetectFailed } else { return probs, nil } } -func (ctx *Context) Whisper_n_len() int { - return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) +func (state *State) Whisper_n_len() int { + return int(C.whisper_n_len((*C.struct_whisper_state)(state))) } func (ctx *Context) Whisper_n_vocab() int { @@ -268,13 +282,13 @@ func Whisper_token_transcribe() Token { } // Performance information -func (ctx *Context) Whisper_print_timings() { - C.whisper_print_timings((*C.struct_whisper_context)(ctx)) +func (ctx *Context) Whisper_print_timings(state *State) { + C.whisper_print_timings((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state)) } // Performance information -func (ctx *Context) Whisper_reset_timings() { - C.whisper_reset_timings((*C.struct_whisper_context)(ctx)) +func (state *State) Whisper_reset_timings() { + C.whisper_reset_timings((*C.struct_whisper_state)(state)) } // Print system information @@ -302,16 +316,30 @@ func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginC } } +// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text with the given state +// Uses the specified decoding strategy to obtain the text. +func (ctx *Context) Whisper_full_with_state(state *State, params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { + registerEncoderBeginCallback(ctx, encoderBeginCallback) + registerNewSegmentCallback(ctx, newSegmentCallback) + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // Split the input audio in chunks and process each chunk separately using whisper_full() // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. -func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { +func (ctx *Context) Whisper_full_parallel(state *State, params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { registerEncoderBeginCallback(ctx, encoderBeginCallback) registerNewSegmentCallback(ctx, newSegmentCallback) defer registerEncoderBeginCallback(ctx, nil) defer registerNewSegmentCallback(ctx, nil) - if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 { + if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 { return nil } else { return ErrConversionFailed @@ -320,49 +348,49 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc // Number of generated text segments. // A segment can be a few words, a sentence, or even a paragraph. -func (ctx *Context) Whisper_full_n_segments() int { - return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx))) +func (state *State) Whisper_full_n_segments() int { + return int(C.whisper_full_n_segments((*C.struct_whisper_state)(state))) } // Get the start and end time of the specified segment. -func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 { - return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment))) +func (state *State) Whisper_full_get_segment_t0(segment int) int64 { + return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_state)(state), C.int(segment))) } // Get the start and end time of the specified segment. -func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 { - return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment))) +func (state *State) Whisper_full_get_segment_t1(segment int) int64 { + return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_state)(state), C.int(segment))) } // Get the text of the specified segment. -func (ctx *Context) Whisper_full_get_segment_text(segment int) string { - return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment))) +func (state *State) Whisper_full_get_segment_text(segment int) string { + return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_state)(state), C.int(segment))) } // Get number of tokens in the specified segment. -func (ctx *Context) Whisper_full_n_tokens(segment int) int { - return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment))) +func (state *State) Whisper_full_n_tokens(segment int) int { + return int(C.whisper_full_n_tokens((*C.struct_whisper_state)(state), C.int(segment))) } // Get the token text of the specified token index in the specified segment. -func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string { - return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +func (ctx *Context) Whisper_full_get_token_text(state *State, segment int, token int) string { + return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token))) } // Get the token of the specified token index in the specified segment. -func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token { - return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +func (state *State) Whisper_full_get_token_id(segment int, token int) Token { + return Token(C.whisper_full_get_token_id((*C.struct_whisper_state)(state), C.int(segment), C.int(token))) } // Get token data for the specified token in the specified segment. // This contains probabilities, timestamps, etc. -func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData { - return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +func (state *State) whisper_full_get_token_data(segment int, token int) TokenData { + return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_state)(state), C.int(segment), C.int(token))) } // Get the probability of the specified token in the specified segment. -func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 { - return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +func (state *State) Whisper_full_get_token_p(segment int, token int) float32 { + return float32(C.whisper_full_get_token_p((*C.struct_whisper_state)(state), C.int(segment), C.int(token))) } /////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go index 2c95c81..de29cbe 100644 --- a/bindings/go/whisper_test.go +++ b/bindings/go/whisper_test.go @@ -49,20 +49,23 @@ func Test_Whisper_001(t *testing.T) { // Run whisper ctx := whisper.Whisper_init(ModelPath) assert.NotNil(ctx) + state := ctx.Whisper_init_state() + assert.NotNil(state) + defer state.Whisper_free_state() defer ctx.Whisper_free() params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) data := buf.AsFloat32Buffer().Data - err = ctx.Whisper_full(params, data, nil, nil) + err = ctx.Whisper_full_with_state(state, params, data, nil, nil) assert.NoError(err) // Print out tokens - num_segments := ctx.Whisper_full_n_segments() + num_segments := state.Whisper_full_n_segments() assert.GreaterOrEqual(num_segments, 1) for i := 0; i < num_segments; i++ { - str := ctx.Whisper_full_get_segment_text(i) + str := state.Whisper_full_get_segment_text(i) assert.NotEmpty(str) - t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond - t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond + t0 := time.Duration(state.Whisper_full_get_segment_t0(i)) * time.Millisecond + t1 := time.Duration(state.Whisper_full_get_segment_t1(i)) * time.Millisecond t.Logf("[%6s->%-6s] %q", t0, t1, str) } } @@ -101,11 +104,16 @@ func Test_Whisper_003(t *testing.T) { assert.NotNil(ctx) defer ctx.Whisper_free() + // Create the state for manual transformation + state := ctx.Whisper_init_state() + assert.NotNil(state) + defer state.Whisper_free_state() + // Get MEL - assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads)) + assert.NoError(ctx.Whisper_pcm_to_mel(state, buf.AsFloat32Buffer().Data, threads)) // Get Languages - languages, err := ctx.Whisper_lang_auto_detect(0, threads) + languages, err := ctx.Whisper_lang_auto_detect(state, 0, threads) assert.NoError(err) for i, p := range languages { t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)