Merge branch 'ggerganov:master' into master

pull/444/head
Alex Bacart, committed by GitHub
commit 7d9583cd5b

@@ -3,6 +3,11 @@ on:
   push:
     paths:
       - bindings/go/**
+      - whisper.h
+  pull_request:
+    paths:
+      - bindings/go/**
+      - whisper.h
 
 jobs:
   ubuntu-latest:

@@ -1,5 +1,5 @@
 name: CI
-on: [push]
+on: [push, pull_request]
 
 jobs:
   ubuntu-latest:

.gitignore (vendored)

@@ -1,4 +1,5 @@
 *.o
+*.a
 .cache/
 .vs/
 .vscode/

@@ -64,10 +64,13 @@ func Process(model whisper.Model, path string, flags *Flags) error {
     // Process the data
     fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
+    context.ResetTimings()
     if err := context.Process(data, cb); err != nil {
         return err
     }
 
+    context.PrintTimings()
+
     // Print out the results
     switch {
     case flags.GetOut() == "srt":
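
The two added calls bracket the transcription so that the printed timings cover only this run. For context, the same reset/print pattern can be used outside the example CLI. Below is a minimal end-to-end sketch, not taken from this diff, assuming the bindings' whisper.New and Model.NewContext constructors, Model.Close, the model path shown, and that Process accepts a nil segment callback; loadSamples is a hypothetical stand-in for WAV decoding.

    package main

    import (
        "log"

        whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
    )

    // loadSamples is a hypothetical helper standing in for WAV decoding;
    // whisper expects 16 kHz mono float32 samples.
    func loadSamples() []float32 {
        return make([]float32, 16000) // one second of silence
    }

    func main() {
        model, err := whisper.New("models/ggml-base.en.bin") // assumed model path
        if err != nil {
            log.Fatal(err)
        }
        defer model.Close()

        context, err := model.NewContext()
        if err != nil {
            log.Fatal(err)
        }

        samples := loadSamples()

        context.ResetTimings() // start from clean counters, as in the example above
        if err := context.Process(samples, nil); err != nil {
            log.Fatal(err)
        }
        context.PrintTimings() // timings now cover only this Process call
    }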

@@ -107,6 +107,16 @@ func (context *context) SetMaxTokensPerSegment(n uint) {
     context.params.SetMaxTokensPerSegment(int(n))
 }
 
+// ResetTimings resets the model timings. Should be called before processing.
+func (context *context) ResetTimings() {
+    context.model.ctx.Whisper_reset_timings()
+}
+
+// PrintTimings prints the model timings to stdout.
+func (context *context) PrintTimings() {
+    context.model.ctx.Whisper_print_timings()
+}
+
 // Process new sample data and return any errors
 func (context *context) Process(data []float32, cb SegmentCallback) error {
     if context.model.ctx == nil {

@@ -60,6 +60,9 @@ type Context interface {
     IsNOT(Token) bool          // Test for "No timestamps" token
     IsLANG(Token, string) bool // Test for token associated with a specific language
     IsText(Token) bool         // Test for text token
+
+    PrintTimings()             // Print the model timings
+    ResetTimings()             // Reset the model timings
 }
 
 // Segment is the text result of a speech recognition.
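
Since PrintTimings and ResetTimings are now part of the exported Context interface, the bracketing can be written once against the interface rather than repeated at each call site. A small sketch under the same assumptions as above; timedContext is a hypothetical wrapper and not part of the bindings.

    package example

    import (
        whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
    )

    // timedContext is a hypothetical decorator: it embeds whisper.Context and
    // overrides Process so every call is bracketed by ResetTimings/PrintTimings.
    type timedContext struct {
        whisper.Context
    }

    func (t timedContext) Process(data []float32, cb whisper.SegmentCallback) error {
        t.ResetTimings()
        err := t.Context.Process(data, cb)
        t.PrintTimings()
        return err
    }

A context obtained from model.NewContext() can then be wrapped as timedContext{Context: ctx} and passed anywhere a whisper.Context is expected.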

File diff suppressed because one or more lines are too long

@@ -88,11 +88,15 @@ async function fetchRemote(url, cbProgress, cbPrint) {
 // - check if the data is already in the IndexedDB
 // - if not, fetch it from the remote URL and store it in the IndexedDB
 function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
-    // query the storage quota and print it
-    navigator.storage.estimate().then(function (estimate) {
-        cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
-        cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
-    });
+    if (!navigator.storage || !navigator.storage.estimate) {
+        cbPrint('loadRemote: navigator.storage.estimate() is not supported');
+    } else {
+        // query the storage quota and print it
+        navigator.storage.estimate().then(function (estimate) {
+            cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
+            cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
+        });
+    }
 
     // check if the data is already in the IndexedDB
     var rq = indexedDB.open(dbName, dbVersion);

@@ -3524,7 +3524,7 @@ int whisper_full(
             prompt.clear();
 
             // if we have already generated some text, use it as a prompt to condition the next generation
-            if (!prompt_past.empty() && t_cur < 0.5f) {
+            if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
                 prompt = { whisper_token_prev(ctx) };
@@ -3535,11 +3535,11 @@ int whisper_full(
             prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
             // print the prompt
-            //WHISPER_PRINT_DEBUG("\n\n");
-            //for (int i = 0; i < (int) prompt.size(); i++) {
-            //    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
-            //}
-            //WHISPER_PRINT_DEBUG("\n\n");
+            WHISPER_PRINT_DEBUG("\n\n");
+            for (int i = 0; i < (int) prompt.size(); i++) {
+                WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+            }
+            WHISPER_PRINT_DEBUG("\n\n");
 
             if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
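
The first hunk makes n_max_text_ctx == 0 disable cross-window prompt conditioning entirely, where previously the value was only used to clamp how many past tokens were carried over and whisper_token_prev was still injected. The sketch below shows how a caller might rely on that; the Go bindings in this diff do not expose a setter for the field, so contextWithMaxTextCtx and SetMaxTextCtx are hypothetical, and only the C field params.n_max_text_ctx is confirmed here.

    package example

    import (
        whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
    )

    // contextWithMaxTextCtx is a hypothetical extension of whisper.Context; the
    // only thing this diff confirms is the C field whisper_full_params.n_max_text_ctx.
    type contextWithMaxTextCtx interface {
        whisper.Context
        SetMaxTextCtx(n int) // hypothetical setter for params.n_max_text_ctx
    }

    // processIndependently decodes samples with prompt conditioning disabled:
    // with the change above, n_max_text_ctx == 0 skips the prompt_past block in
    // whisper_full, so each window is decoded without tokens from earlier windows.
    func processIndependently(ctx contextWithMaxTextCtx, samples []float32) error {
        ctx.SetMaxTextCtx(0)
        return ctx.Process(samples, nil)
    }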
