package main import ( "context" "flag" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "syscall" "time" ) /////////////////////////////////////////////////////////////////////////////// // CONSTANTS const ( srcUrl = "https://huggingface.co/" // The location of the models srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix srcExt = ".bin" // Filename extension bufSize = 1024 * 64 // Size of the buffer used for downloading the model ) var ( // The models which will be downloaded, if no model is specified as an argument modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"} ) var ( // The output folder. When not set, use current working directory. flagOut = flag.String("out", "", "Output folder") // HTTP timeout parameter - will timeout if takes longer than this to download a model flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout") // Quiet parameter - will not print progress if set flagQuiet = flag.Bool("quiet", false, "Quiet mode") ) /////////////////////////////////////////////////////////////////////////////// // MAIN func main() { flag.Usage = func() { name := filepath.Base(flag.CommandLine.Name()) fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] \n\n", name) flag.PrintDefaults() } flag.Parse() // Get output path out, err := GetOut() if err != nil { fmt.Fprintln(os.Stderr, "Error:", err) os.Exit(-1) } // Create context which quits on SIGINT or SIGQUIT ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT) // Progress filehandle progress := os.Stdout if *flagQuiet { progress, err = os.Open(os.DevNull) if err != nil { fmt.Fprintln(os.Stderr, "Error:", err) os.Exit(-1) } defer progress.Close() } // Download models - exit on error or interrupt for _, model := range GetModels() { url, err := URLForModel(model) if err != nil { fmt.Fprintln(os.Stderr, "Error:", err) continue } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF { continue } else if err == context.Canceled { os.Remove(path) fmt.Fprintln(progress, "\nInterrupted") break } else if err == context.DeadlineExceeded { os.Remove(path) fmt.Fprintln(progress, "Timeout downloading model") continue } else { os.Remove(path) fmt.Fprintln(os.Stderr, "Error:", err) break } } } /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS // GetOut returns the path to the output directory func GetOut() (string, error) { if *flagOut == "" { return os.Getwd() } if info, err := os.Stat(*flagOut); err != nil { return "", err } else if !info.IsDir() { return "", fmt.Errorf("not a directory: %s", info.Name()) } else { return *flagOut, nil } } // GetModels returns the list of models to download func GetModels() []string { if flag.NArg() == 0 { return modelNames } else { return flag.Args() } } // URLForModel returns the URL for the given model on huggingface.co func URLForModel(model string) (string, error) { url, err := url.Parse(srcUrl) if err != nil { return "", err } else { url.Path = srcPathPrefix + "-" + model + srcExt } return url.String(), nil } // Download downloads the model from the given URL to the given output directory func Download(ctx context.Context, p io.Writer, model, out string) (string, error) { // Create HTTP client client := http.Client{ Timeout: *flagTimeout, } // Initiate the download req, err := http.NewRequest("GET", model, nil) if err != nil { return "", err } resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("%s: %s", model, resp.Status) } // If output file exists and is the same size as the model, skip path := filepath.Join(out, filepath.Base(model)) if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength { fmt.Fprintln(p, "Skipping", model, "as it already exists") return "", nil } // Create file w, err := os.Create(path) if err != nil { return "", err } defer w.Close() // Report fmt.Fprintln(p, "Downloading", model, "to", out) // Progressively download the model data := make([]byte, bufSize) count, pct := int64(0), int64(0) ticker := time.NewTicker(5 * time.Second) for { select { case <-ctx.Done(): // Cancelled, return error return path, ctx.Err() case <-ticker.C: pct = DownloadReport(p, pct, count, resp.ContentLength) default: // Read body n, err := resp.Body.Read(data) if err != nil { DownloadReport(p, pct, count, resp.ContentLength) return path, err } else if m, err := w.Write(data[:n]); err != nil { return path, err } else { count += int64(m) } } } } // Report periodically reports the download progress when percentage changes func DownloadReport(w io.Writer, pct, count, total int64) int64 { pct_ := count * 100 / total if pct_ > pct { fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_) } return pct_ }