You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
5.2 KiB
207 lines
5.2 KiB
package main
|
|
|
|
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"syscall"
	"time"
)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// CONSTANTS
|
|
|
|
const (
	// srcUrl is the location of the models.
	srcUrl = "https://huggingface.co/"

	// srcPathPrefix is the filename prefix; the model name and srcExt are
	// appended to it to form the URL path.
	srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml"

	// srcExt is the filename extension.
	srcExt = ".bin"

	// bufSize is the size in bytes of the buffer used for downloading the model.
	bufSize = 1024 * 64
)
|
|
|
|
// modelNames lists the models which will be downloaded when no model is
// specified as a command-line argument.
var modelNames = []string{
	"tiny.en", "tiny",
	"base.en", "base",
	"small.en", "small",
	"medium.en", "medium",
	"large-v1", "large",
}
|
|
|
|
// Command-line flags.
var (
	// flagOut is the output folder. When left empty, the current working
	// directory is used.
	flagOut = flag.String("out", "", "Output folder")

	// flagTimeout is the HTTP timeout: a model download is abandoned if it
	// takes longer than this.
	flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")

	// flagQuiet suppresses progress output when set.
	flagQuiet = flag.Bool("quiet", false, "Quiet mode")
)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// MAIN
|
|
|
|
func main() {
|
|
flag.Usage = func() {
|
|
name := filepath.Base(flag.CommandLine.Name())
|
|
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
|
flag.PrintDefaults()
|
|
}
|
|
flag.Parse()
|
|
|
|
// Get output path
|
|
out, err := GetOut()
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
|
os.Exit(-1)
|
|
}
|
|
|
|
// Create context which quits on SIGINT or SIGQUIT
|
|
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
|
|
|
// Progress filehandle
|
|
progress := os.Stdout
|
|
if *flagQuiet {
|
|
progress, err = os.Open(os.DevNull)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
|
os.Exit(-1)
|
|
}
|
|
defer progress.Close()
|
|
}
|
|
|
|
// Download models - exit on error or interrupt
|
|
for _, model := range GetModels() {
|
|
url, err := URLForModel(model)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
|
continue
|
|
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
|
continue
|
|
} else if err == context.Canceled {
|
|
os.Remove(path)
|
|
fmt.Fprintln(progress, "\nInterrupted")
|
|
break
|
|
} else if err == context.DeadlineExceeded {
|
|
os.Remove(path)
|
|
fmt.Fprintln(progress, "Timeout downloading model")
|
|
continue
|
|
} else {
|
|
os.Remove(path)
|
|
fmt.Fprintln(os.Stderr, "Error:", err)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// PUBLIC METHODS
|
|
|
|
// GetOut returns the path to the output directory
|
|
func GetOut() (string, error) {
|
|
if *flagOut == "" {
|
|
return os.Getwd()
|
|
}
|
|
if info, err := os.Stat(*flagOut); err != nil {
|
|
return "", err
|
|
} else if !info.IsDir() {
|
|
return "", fmt.Errorf("not a directory: %s", info.Name())
|
|
} else {
|
|
return *flagOut, nil
|
|
}
|
|
}
|
|
|
|
// GetModels returns the list of models to download
|
|
func GetModels() []string {
|
|
if flag.NArg() == 0 {
|
|
return modelNames
|
|
} else {
|
|
return flag.Args()
|
|
}
|
|
}
|
|
|
|
// URLForModel returns the URL for the given model on huggingface.co
|
|
func URLForModel(model string) (string, error) {
|
|
url, err := url.Parse(srcUrl)
|
|
if err != nil {
|
|
return "", err
|
|
} else {
|
|
url.Path = srcPathPrefix + "-" + model + srcExt
|
|
}
|
|
return url.String(), nil
|
|
}
|
|
|
|
// Download downloads the model from the given URL to the given output directory
|
|
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
|
// Create HTTP client
|
|
client := http.Client{
|
|
Timeout: *flagTimeout,
|
|
}
|
|
|
|
// Initiate the download
|
|
req, err := http.NewRequest("GET", model, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
|
}
|
|
|
|
// If output file exists and is the same size as the model, skip
|
|
path := filepath.Join(out, filepath.Base(model))
|
|
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
|
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
|
return "", nil
|
|
}
|
|
|
|
// Create file
|
|
w, err := os.Create(path)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer w.Close()
|
|
|
|
// Report
|
|
fmt.Fprintln(p, "Downloading", model, "to", out)
|
|
|
|
// Progressively download the model
|
|
data := make([]byte, bufSize)
|
|
count, pct := int64(0), int64(0)
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// Cancelled, return error
|
|
return path, ctx.Err()
|
|
case <-ticker.C:
|
|
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
|
default:
|
|
// Read body
|
|
n, err := resp.Body.Read(data)
|
|
if err != nil {
|
|
DownloadReport(p, pct, count, resp.ContentLength)
|
|
return path, err
|
|
} else if m, err := w.Write(data[:n]); err != nil {
|
|
return path, err
|
|
} else {
|
|
count += int64(m)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// DownloadReport writes a progress line to w when the completed percentage has
// increased beyond pct, and returns the newly computed percentage.
//
// pct is the percentage returned by the previous call, count the number of
// bytes written so far and total the expected size in bytes. When total is not
// positive (for example an unknown content length) no percentage can be
// computed: nothing is written and pct is returned unchanged.
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
	if total <= 0 {
		// Avoid dividing by zero or reporting a negative percentage
		return pct
	}
	current := count * 100 / total
	if current > pct {
		fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, current)
	}
	return current
}
|