Fixed go bindings + js

pull/494/head
Sandro Hanea 2 years ago
parent d4ee0266ef
commit aac01503cc

@ -118,12 +118,12 @@ func (context *context) SetMaxTokensPerSegment(n uint) {
// ResetTimings resets the mode timings. Should be called before processing // ResetTimings resets the mode timings. Should be called before processing
func (context *context) ResetTimings() { func (context *context) ResetTimings() {
context.model.ctx.Whisper_reset_timings() context.model.state.Whisper_reset_timings()
} }
// PrintTimings prints the model timings to stdout. // PrintTimings prints the model timings to stdout.
func (context *context) PrintTimings() { func (context *context) PrintTimings() {
context.model.ctx.Whisper_print_timings() context.model.ctx.Whisper_print_timings(context.model.state)
} }
// SystemInfo returns the system information // SystemInfo returns the system information
@ -139,7 +139,7 @@ func (context *context) SystemInfo() string {
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages. // Returns the probabilities of all languages.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) { func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads) langProbs, err := context.model.ctx.Whisper_lang_auto_detect(context.model.state, offset_ms, n_threads)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -159,23 +159,23 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
// We don't do parallel processing at the moment // We don't do parallel processing at the moment
processors := 0 processors := 0
if processors > 1 { if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { if err := context.model.ctx.Whisper_full_parallel(context.model.state, context.params, data, processors, nil, func(new int) {
if cb != nil { if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments() num_segments := context.model.state.Whisper_full_n_segments()
s0 := num_segments - new s0 := num_segments - new
for i := s0; i < num_segments; i++ { for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i)) cb(toSegment(context.model, i))
} }
} }
}); err != nil { }); err != nil {
return err return err
} }
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { } else if err := context.model.ctx.Whisper_full_with_state( context.model.state, context.params, data, nil, func(new int) {
if cb != nil { if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments() num_segments := context.model.state.Whisper_full_n_segments()
s0 := num_segments - new s0 := num_segments - new
for i := s0; i < num_segments; i++ { for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i)) cb(toSegment(context.model, i))
} }
} }
}); err != nil { }); err != nil {
@ -188,15 +188,15 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
// Return the next segment of tokens // Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) { func (context *context) NextSegment() (Segment, error) {
if context.model.ctx == nil { if context.model.state == nil {
return Segment{}, ErrInternalAppError return Segment{}, ErrInternalAppError
} }
if context.n >= context.model.ctx.Whisper_full_n_segments() { if context.n >= context.model.state.Whisper_full_n_segments() {
return Segment{}, io.EOF return Segment{}, io.EOF
} }
// Populate result // Populate result
result := toSegment(context.model.ctx, context.n) result := toSegment(context.model, context.n)
// Increment the cursor // Increment the cursor
context.n++ context.n++
@ -267,23 +267,23 @@ func (context *context) IsLANG(t Token, lang string) bool {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS
func toSegment(ctx *whisper.Context, n int) Segment { func toSegment(model *model, n int) Segment {
return Segment{ return Segment{
Num: n, Num: n,
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), Text: strings.TrimSpace(model.state.Whisper_full_get_segment_text(n)),
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, Start: time.Duration(model.state.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, End: time.Duration(model.state.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
Tokens: toTokens(ctx, n), Tokens: toTokens(model, n),
} }
} }
func toTokens(ctx *whisper.Context, n int) []Token { func toTokens(model *model, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n)) result := make([]Token, model.state.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ { for i := 0; i < len(result); i++ {
result[i] = Token{ result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)), Id: int(model.state.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)), Text: strings.TrimSpace(model.ctx.Whisper_full_get_token_text(model.state, n, i)),
P: ctx.Whisper_full_get_token_p(n, i), P: model.state.Whisper_full_get_token_p(n, i),
} }
} }
return result return result

@ -15,6 +15,7 @@ import (
type model struct { type model struct {
path string path string
ctx *whisper.Context ctx *whisper.Context
state *whisper.State
} }
// Make sure model adheres to the interface // Make sure model adheres to the interface
@ -30,6 +31,11 @@ func New(path string) (Model, error) {
} else if ctx := whisper.Whisper_init(path); ctx == nil { } else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel return nil, ErrUnableToLoadModel
} else { } else {
state := ctx.Whisper_init_state();
if state == nil {
return nil, ErrUnableToLoadModel
}
model.state = state
model.ctx = ctx model.ctx = ctx
model.path = path model.path = path
} }
@ -43,8 +49,13 @@ func (model *model) Close() error {
model.ctx.Whisper_free() model.ctx.Whisper_free()
} }
if model.state != nil {
model.state.Whisper_free_state()
}
// Release resources // Release resources
model.ctx = nil model.ctx = nil
model.state = nil
// Return success // Return success
return nil return nil

@ -20,8 +20,8 @@ extern bool callEncoderBegin(void* user_data);
// Text segment callback // Text segment callback
// Called on every newly generated text segment // Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) { static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL && state != NULL) {
callNewSegment(user_data, n_new); callNewSegment(user_data, n_new);
} }
} }
@ -29,8 +29,8 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
// Encoder begin callback // Encoder begin callback
// If not NULL, called before the encoder starts // If not NULL, called before the encoder starts
// If it returns false, the computation is aborted // If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) { static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL && state != NULL) {
return callEncoderBegin(user_data); return callEncoderBegin(user_data);
} }
return false; return false;
@ -53,6 +53,7 @@ import "C"
type ( type (
Context C.struct_whisper_context Context C.struct_whisper_context
State C.struct_whisper_state
Token C.whisper_token Token C.whisper_token
TokenData C.struct_whisper_token_data TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy SamplingStrategy C.enum_whisper_sampling_strategy
@ -98,15 +99,28 @@ func Whisper_init(path string) *Context {
} }
} }
func (ctx *Context) Whisper_init_state() *State {
state := C.whisper_init_state((*C.struct_whisper_context)(ctx))
if state != nil {
return (*State)(state)
} else {
return nil
}
}
// Frees all memory allocated by the model. // Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() { func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx)) C.whisper_free((*C.struct_whisper_context)(ctx))
} }
func (state *State) Whisper_free_state() {
C.whisper_free_state((*C.struct_whisper_state)(state))
}
// Convert RAW PCM audio to log mel spectrogram. // Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context. // The resulting spectrogram is stored inside the provided whisper context.
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { func (ctx *Context) Whisper_pcm_to_mel(state *State, data []float32, threads int) error {
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx),(*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
return nil return nil
} else { } else {
return ErrConversionFailed return ErrConversionFailed
@ -116,8 +130,8 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
// This can be used to set a custom log mel spectrogram inside the provided whisper context. // This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { func (state *State) Whisper_set_mel(data []float32, n_mel int) error {
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { if C.whisper_set_mel((*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return nil return nil
} else { } else {
return ErrConversionFailed return ErrConversionFailed
@ -127,8 +141,8 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram. // offset can be used to specify the offset of the first frame in the spectrogram.
func (ctx *Context) Whisper_encode(offset, threads int) error { func (ctx *Context) Whisper_encode(state *State, offset, threads int) error {
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 { if C.whisper_encode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 {
return nil return nil
} else { } else {
return ErrConversionFailed return ErrConversionFailed
@ -139,8 +153,8 @@ func (ctx *Context) Whisper_encode(offset, threads int) error {
// Make sure to call whisper_encode() first. // Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder. // tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls. // n_past is the number of tokens to use from previous decoder calls.
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error { func (ctx *Context) Whisper_decode(state *State, tokens []Token, past, threads int) error {
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
return nil return nil
} else { } else {
return ErrConversionFailed return ErrConversionFailed
@ -183,17 +197,17 @@ func Whisper_lang_str(id int) string {
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages. // Returns the probabilities of all languages.
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) { func (ctx *Context) Whisper_lang_auto_detect(state *State, offset_ms, n_threads int) ([]float32, error) {
probs := make([]float32, Whisper_lang_max_id()+1) probs := make([]float32, Whisper_lang_max_id()+1)
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
return nil, ErrAutoDetectFailed return nil, ErrAutoDetectFailed
} else { } else {
return probs, nil return probs, nil
} }
} }
func (ctx *Context) Whisper_n_len() int { func (state *State) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) return int(C.whisper_n_len((*C.struct_whisper_state)(state)))
} }
func (ctx *Context) Whisper_n_vocab() int { func (ctx *Context) Whisper_n_vocab() int {
@ -268,13 +282,13 @@ func Whisper_token_transcribe() Token {
} }
// Performance information // Performance information
func (ctx *Context) Whisper_print_timings() { func (ctx *Context) Whisper_print_timings(state *State) {
C.whisper_print_timings((*C.struct_whisper_context)(ctx)) C.whisper_print_timings((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state))
} }
// Performance information // Performance information
func (ctx *Context) Whisper_reset_timings() { func (state *State) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx)) C.whisper_reset_timings((*C.struct_whisper_state)(state))
} }
// Print system information // Print system information
@ -302,16 +316,30 @@ func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginC
} }
} }
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text with the given state
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full_with_state(state *State, params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Split the input audio in chunks and process each chunk separately using whisper_full() // Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { func (ctx *Context) Whisper_full_parallel(state *State, params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback) registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback) registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil) defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil) defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 { if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
return nil return nil
} else { } else {
return ErrConversionFailed return ErrConversionFailed
@ -320,49 +348,49 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc
// Number of generated text segments. // Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph. // A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int { func (state *State) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx))) return int(C.whisper_full_n_segments((*C.struct_whisper_state)(state)))
} }
// Get the start and end time of the specified segment. // Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 { func (state *State) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment))) return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_state)(state), C.int(segment)))
} }
// Get the start and end time of the specified segment. // Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 { func (state *State) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment))) return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_state)(state), C.int(segment)))
} }
// Get the text of the specified segment. // Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string { func (state *State) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment))) return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_state)(state), C.int(segment)))
} }
// Get number of tokens in the specified segment. // Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int { func (state *State) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment))) return int(C.whisper_full_n_tokens((*C.struct_whisper_state)(state), C.int(segment)))
} }
// Get the token text of the specified token index in the specified segment. // Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string { func (ctx *Context) Whisper_full_get_token_text(state *State, segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
} }
// Get the token of the specified token index in the specified segment. // Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token { func (state *State) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) return Token(C.whisper_full_get_token_id((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
} }
// Get token data for the specified token in the specified segment. // Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc. // This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData { func (state *State) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
} }
// Get the probability of the specified token in the specified segment. // Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 { func (state *State) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) return float32(C.whisper_full_get_token_p((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

@ -49,20 +49,23 @@ func Test_Whisper_001(t *testing.T) {
// Run whisper // Run whisper
ctx := whisper.Whisper_init(ModelPath) ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx) assert.NotNil(ctx)
state := ctx.Whisper_init_state()
assert.NotNil(state)
defer state.Whisper_free_state()
defer ctx.Whisper_free() defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil) err = ctx.Whisper_full_with_state(state, params, data, nil, nil)
assert.NoError(err) assert.NoError(err)
// Print out tokens // Print out tokens
num_segments := ctx.Whisper_full_n_segments() num_segments := state.Whisper_full_n_segments()
assert.GreaterOrEqual(num_segments, 1) assert.GreaterOrEqual(num_segments, 1)
for i := 0; i < num_segments; i++ { for i := 0; i < num_segments; i++ {
str := ctx.Whisper_full_get_segment_text(i) str := state.Whisper_full_get_segment_text(i)
assert.NotEmpty(str) assert.NotEmpty(str)
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond t0 := time.Duration(state.Whisper_full_get_segment_t0(i)) * time.Millisecond
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond t1 := time.Duration(state.Whisper_full_get_segment_t1(i)) * time.Millisecond
t.Logf("[%6s->%-6s] %q", t0, t1, str) t.Logf("[%6s->%-6s] %q", t0, t1, str)
} }
} }
@ -101,11 +104,16 @@ func Test_Whisper_003(t *testing.T) {
assert.NotNil(ctx) assert.NotNil(ctx)
defer ctx.Whisper_free() defer ctx.Whisper_free()
// Create the state for manual transformation
state := ctx.Whisper_init_state()
assert.NotNil(state)
defer state.Whisper_free_state()
// Get MEL // Get MEL
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads)) assert.NoError(ctx.Whisper_pcm_to_mel(state, buf.AsFloat32Buffer().Data, threads))
// Get Languages // Get Languages
languages, err := ctx.Whisper_lang_auto_detect(0, threads) languages, err := ctx.Whisper_lang_auto_detect(state, 0, threads)
assert.NoError(err) assert.NoError(err)
for i, p := range languages { for i, p := range languages {
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p) t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)

Loading…
Cancel
Save