You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
252 lines
6.4 KiB
252 lines
6.4 KiB
package whisper
|
|
|
|
import (
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
// Bindings
|
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
|
)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// TYPES
|
|
|
|
type context struct {
|
|
n int
|
|
model *model
|
|
params whisper.Params
|
|
}
|
|
|
|
// Make sure context adheres to the interface
|
|
var _ Context = (*context)(nil)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// LIFECYCLE
|
|
|
|
func newContext(model *model, params whisper.Params) (Context, error) {
|
|
context := new(context)
|
|
context.model = model
|
|
context.params = params
|
|
|
|
// Return success
|
|
return context, nil
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// PUBLIC METHODS
|
|
|
|
// Set the language to use for speech recognition.
|
|
func (context *context) SetLanguage(lang string) error {
|
|
if context.model.ctx == nil {
|
|
return ErrInternalAppError
|
|
}
|
|
if !context.model.IsMultilingual() {
|
|
return ErrModelNotMultilingual
|
|
}
|
|
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
|
return ErrUnsupportedLanguage
|
|
} else if err := context.params.SetLanguage(id); err != nil {
|
|
return err
|
|
}
|
|
// Return success
|
|
return nil
|
|
}
|
|
|
|
func (context *context) IsMultilingual() bool {
|
|
return context.model.IsMultilingual()
|
|
}
|
|
|
|
// Get language
|
|
func (context *context) Language() string {
|
|
return whisper.Whisper_lang_str(context.params.Language())
|
|
}
|
|
|
|
// Set translate flag
|
|
func (context *context) SetTranslate(v bool) {
|
|
context.params.SetTranslate(v)
|
|
}
|
|
|
|
// Set speedup flag
|
|
func (context *context) SetSpeedup(v bool) {
|
|
context.params.SetSpeedup(v)
|
|
}
|
|
|
|
// Set number of threads to use
|
|
func (context *context) SetThreads(v uint) {
|
|
context.params.SetThreads(int(v))
|
|
}
|
|
|
|
// Set time offset
|
|
func (context *context) SetOffset(v time.Duration) {
|
|
context.params.SetOffset(int(v.Milliseconds()))
|
|
}
|
|
|
|
// Set duration of audio to process
|
|
func (context *context) SetDuration(v time.Duration) {
|
|
context.params.SetOffset(int(v.Milliseconds()))
|
|
}
|
|
|
|
// Set timestamp token probability threshold (~0.01)
|
|
func (context *context) SetTokenThreshold(t float32) {
|
|
context.params.SetTokenThreshold(t)
|
|
}
|
|
|
|
// Set timestamp token sum probability threshold (~0.01)
|
|
func (context *context) SetTokenSumThreshold(t float32) {
|
|
context.params.SetTokenSumThreshold(t)
|
|
}
|
|
|
|
// Set max segment length in characters
|
|
func (context *context) SetMaxSegmentLength(n uint) {
|
|
context.params.SetMaxSegmentLength(int(n))
|
|
}
|
|
|
|
// Set max tokens per segment (0 = no limit)
|
|
func (context *context) SetMaxTokensPerSegment(n uint) {
|
|
context.params.SetMaxTokensPerSegment(int(n))
|
|
}
|
|
|
|
// Process new sample data and return any errors
|
|
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
|
if context.model.ctx == nil {
|
|
return ErrInternalAppError
|
|
}
|
|
// If the callback is defined then we force on single_segment mode
|
|
if cb != nil {
|
|
context.params.SetSingleSegment(true)
|
|
}
|
|
|
|
// We don't do parallel processing at the moment
|
|
processors := 0
|
|
if processors > 1 {
|
|
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
|
if cb != nil {
|
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
|
s0 := num_segments - new
|
|
for i := s0; i < num_segments; i++ {
|
|
cb(toSegment(context.model.ctx, i))
|
|
}
|
|
}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
|
if cb != nil {
|
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
|
s0 := num_segments - new
|
|
for i := s0; i < num_segments; i++ {
|
|
cb(toSegment(context.model.ctx, i))
|
|
}
|
|
}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Return success
|
|
return nil
|
|
}
|
|
|
|
// Return the next segment of tokens
|
|
func (context *context) NextSegment() (Segment, error) {
|
|
if context.model.ctx == nil {
|
|
return Segment{}, ErrInternalAppError
|
|
}
|
|
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
|
return Segment{}, io.EOF
|
|
}
|
|
|
|
// Populate result
|
|
result := toSegment(context.model.ctx, context.n)
|
|
|
|
// Increment the cursor
|
|
context.n++
|
|
|
|
// Return success
|
|
return result, nil
|
|
}
|
|
|
|
// Test for text tokens
|
|
func (context *context) IsText(t Token) bool {
|
|
switch {
|
|
case context.IsBEG(t):
|
|
return false
|
|
case context.IsSOT(t):
|
|
return false
|
|
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
|
return false
|
|
case context.IsPREV(t):
|
|
return false
|
|
case context.IsSOLM(t):
|
|
return false
|
|
case context.IsNOT(t):
|
|
return false
|
|
default:
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Test for "begin" token
|
|
func (context *context) IsBEG(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
|
}
|
|
|
|
// Test for "start of transcription" token
|
|
func (context *context) IsSOT(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
|
}
|
|
|
|
// Test for "end of transcription" token
|
|
func (context *context) IsEOT(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
|
}
|
|
|
|
// Test for "start of prev" token
|
|
func (context *context) IsPREV(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
|
}
|
|
|
|
// Test for "start of lm" token
|
|
func (context *context) IsSOLM(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
|
}
|
|
|
|
// Test for "No timestamps" token
|
|
func (context *context) IsNOT(t Token) bool {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
|
}
|
|
|
|
// Test for token associated with a specific language
|
|
func (context *context) IsLANG(t Token, lang string) bool {
|
|
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
|
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
|
} else {
|
|
return false
|
|
}
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// PRIVATE METHODS
|
|
|
|
func toSegment(ctx *whisper.Context, n int) Segment {
|
|
return Segment{
|
|
Num: n,
|
|
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
|
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
|
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
|
Tokens: toTokens(ctx, n),
|
|
}
|
|
}
|
|
|
|
func toTokens(ctx *whisper.Context, n int) []Token {
|
|
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
|
for i := 0; i < len(result); i++ {
|
|
result[i] = Token{
|
|
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
|
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
|
P: ctx.Whisper_full_get_token_p(n, i),
|
|
}
|
|
}
|
|
return result
|
|
}
|