You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
whisper.cpp/bindings/go/examples/go-whisper/flags.go

157 lines
4.5 KiB

package main
import (
"flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type Flags struct {
*flag.FlagSet
}
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewFlags(name string, args []string) (*Flags, error) {
flags := &Flags{
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
}
// Register the command line arguments
registerFlags(flags)
// Parse command line
if err := flags.Parse(args); err != nil {
return nil, err
}
// Return success
return flags, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (flags *Flags) GetModel() string {
return flags.Lookup("model").Value.String()
}
func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}
func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}
func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}
func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}
func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}
func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
}