You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
3.2 KiB
131 lines
3.2 KiB
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"time"
|
|
|
|
// Package imports
|
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
wav "github.com/go-audio/wav"
|
|
)
|
|
|
|
func Process(model whisper.Model, path string, flags *Flags) error {
|
|
var data []float32
|
|
|
|
// Create processing context
|
|
context, err := model.NewContext()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set the parameters
|
|
if err := flags.SetParams(context); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Open the file
|
|
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
|
fh, err := os.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer fh.Close()
|
|
|
|
// Decode the WAV file - load the full buffer
|
|
dec := wav.NewDecoder(fh)
|
|
if buf, err := dec.FullPCMBuffer(); err != nil {
|
|
return err
|
|
} else if dec.SampleRate != whisper.SampleRate {
|
|
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
|
} else if dec.NumChans != 1 {
|
|
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
|
} else {
|
|
data = buf.AsFloat32Buffer().Data
|
|
}
|
|
|
|
// Segment callback when -tokens is specified
|
|
var cb whisper.SegmentCallback
|
|
if flags.IsTokens() {
|
|
cb = func(segment whisper.Segment) {
|
|
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
|
for _, token := range segment.Tokens {
|
|
if flags.IsColorize() && context.IsText(token) {
|
|
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
|
|
} else {
|
|
fmt.Fprint(flags.Output(), token.Text, " ")
|
|
}
|
|
}
|
|
fmt.Fprintln(flags.Output(), "")
|
|
fmt.Fprintln(flags.Output(), "")
|
|
}
|
|
}
|
|
|
|
// Process the data
|
|
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
|
context.ResetTimings()
|
|
if err := context.Process(data, cb); err != nil {
|
|
return err
|
|
}
|
|
|
|
context.PrintTimings()
|
|
|
|
// Print out the results
|
|
switch {
|
|
case flags.GetOut() == "srt":
|
|
return OutputSRT(os.Stdout, context)
|
|
case flags.GetOut() == "none":
|
|
return nil
|
|
default:
|
|
return Output(os.Stdout, context, flags.IsColorize())
|
|
}
|
|
}
|
|
|
|
// Output text as SRT file
|
|
func OutputSRT(w io.Writer, context whisper.Context) error {
|
|
n := 1
|
|
for {
|
|
segment, err := context.NextSegment()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
fmt.Fprintln(w, n)
|
|
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
|
|
fmt.Fprintln(w, segment.Text)
|
|
fmt.Fprintln(w, "")
|
|
n++
|
|
}
|
|
}
|
|
|
|
// Output text to terminal
|
|
func Output(w io.Writer, context whisper.Context, colorize bool) error {
|
|
for {
|
|
segment, err := context.NextSegment()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
|
if colorize {
|
|
for _, token := range segment.Tokens {
|
|
if !context.IsText(token) {
|
|
continue
|
|
}
|
|
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
|
|
}
|
|
fmt.Fprint(w, "\n")
|
|
} else {
|
|
fmt.Fprintln(w, " ", segment.Text)
|
|
}
|
|
}
|
|
}
|
|
|
|
// srtTimestamp formats t as an SRT timestamp "HH:MM:SS,mmm".
//
// The hours field is zero-padded to at least two digits and is not wrapped
// at 24. Any precision below one millisecond is truncated.
func srtTimestamp(t time.Duration) string {
	const (
		msPerSecond = int64(time.Second / time.Millisecond)
		msPerMinute = int64(time.Minute / time.Millisecond)
		msPerHour   = int64(time.Hour / time.Millisecond)
	)
	ms := t.Milliseconds()
	hh := ms / msPerHour
	mm := ms / msPerMinute % 60
	ss := ms / msPerSecond % 60
	frac := ms % msPerSecond
	return fmt.Sprintf("%02d:%02d:%02d,%03d", hh, mm, ss, frac)
}
|