You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
157 lines
4.5 KiB
157 lines
4.5 KiB
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
// Packages
|
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// TYPES
|
|
|
|
// Flags holds this program's command-line options. It embeds a
// *flag.FlagSet, so FlagSet methods such as Parse, Lookup, Output and Args
// can be called directly on a *Flags.
type Flags struct {
	*flag.FlagSet // embedded: methods are promoted onto Flags
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// LIFECYCLE
|
|
|
|
func NewFlags(name string, args []string) (*Flags, error) {
|
|
flags := &Flags{
|
|
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
|
}
|
|
|
|
// Register the command line arguments
|
|
registerFlags(flags)
|
|
|
|
// Parse command line
|
|
if err := flags.Parse(args); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Return success
|
|
return flags, nil
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// PUBLIC METHODS
|
|
|
|
func (flags *Flags) GetModel() string {
|
|
return flags.Lookup("model").Value.String()
|
|
}
|
|
|
|
func (flags *Flags) GetLanguage() string {
|
|
return flags.Lookup("language").Value.String()
|
|
}
|
|
|
|
func (flags *Flags) IsTranslate() bool {
|
|
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
|
|
}
|
|
|
|
func (flags *Flags) GetOffset() time.Duration {
|
|
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
|
|
}
|
|
|
|
func (flags *Flags) GetDuration() time.Duration {
|
|
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
|
|
}
|
|
|
|
func (flags *Flags) GetThreads() uint {
|
|
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
|
|
}
|
|
|
|
func (flags *Flags) GetOut() string {
|
|
return strings.ToLower(flags.Lookup("out").Value.String())
|
|
}
|
|
|
|
func (flags *Flags) IsSpeedup() bool {
|
|
return flags.Lookup("speedup").Value.String() == "true"
|
|
}
|
|
|
|
func (flags *Flags) IsTokens() bool {
|
|
return flags.Lookup("tokens").Value.String() == "true"
|
|
}
|
|
|
|
func (flags *Flags) IsColorize() bool {
|
|
return flags.Lookup("colorize").Value.String() == "true"
|
|
}
|
|
|
|
func (flags *Flags) GetMaxLen() uint {
|
|
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
|
|
}
|
|
|
|
func (flags *Flags) GetMaxTokens() uint {
|
|
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
|
|
}
|
|
|
|
func (flags *Flags) GetWordThreshold() float32 {
|
|
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
|
|
}
|
|
|
|
func (flags *Flags) SetParams(context whisper.Context) error {
|
|
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
|
|
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
|
|
if err := context.SetLanguage(lang); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if flags.IsTranslate() && context.IsMultilingual() {
|
|
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
|
|
context.SetTranslate(true)
|
|
}
|
|
if offset := flags.GetOffset(); offset != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
|
|
context.SetOffset(offset)
|
|
}
|
|
if duration := flags.GetDuration(); duration != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
|
context.SetDuration(duration)
|
|
}
|
|
if flags.IsSpeedup() {
|
|
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
|
context.SetSpeedup(true)
|
|
}
|
|
if threads := flags.GetThreads(); threads != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
|
context.SetThreads(threads)
|
|
}
|
|
if max_len := flags.GetMaxLen(); max_len != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
|
|
context.SetMaxSegmentLength(max_len)
|
|
}
|
|
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
|
|
context.SetMaxTokensPerSegment(max_tokens)
|
|
}
|
|
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
|
|
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
|
|
context.SetTokenThreshold(word_threshold)
|
|
}
|
|
|
|
// Return success
|
|
return nil
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// PRIVATE METHODS
|
|
|
|
func registerFlags(flag *Flags) {
|
|
flag.String("model", "", "Path to the model file")
|
|
flag.String("language", "", "Spoken language")
|
|
flag.Bool("translate", false, "Translate from source language to english")
|
|
flag.Duration("offset", 0, "Time offset")
|
|
flag.Duration("duration", 0, "Duration of audio to process")
|
|
flag.Uint("threads", 0, "Number of threads to use")
|
|
flag.Bool("speedup", false, "Enable speedup")
|
|
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
|
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
|
flag.Float64("word-thold", 0, "Maximum segment score")
|
|
flag.Bool("tokens", false, "Display tokens")
|
|
flag.Bool("colorize", false, "Colorize tokens")
|
|
flag.String("out", "", "Output format (srt, none or leave as empty string)")
|
|
}
|