diff --git a/bindings/javascript/emscripten.cpp b/bindings/javascript/emscripten.cpp index 789ad8b..6ef68f0 100644 --- a/bindings/javascript/emscripten.cpp +++ b/bindings/javascript/emscripten.cpp @@ -16,13 +16,15 @@ #include struct whisper_context * g_context; +struct whisper_state * g_state; EMSCRIPTEN_BINDINGS(whisper) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { if (g_context == nullptr) { g_context = whisper_init_from_file(path_model.c_str()); if (g_context != nullptr) { - return true; + g_state = whisper_init_state(g_context); + return g_state != nullptr; } else { return false; } @@ -36,10 +38,14 @@ EMSCRIPTEN_BINDINGS(whisper) { whisper_free(g_context); g_context = nullptr; } + if (g_state) { + whisper_free_state(g_state); + g_state = nullptr; + } })); emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) { - if (g_context == nullptr) { + if (g_context == nullptr || g_state == nullptr) { return -1; } @@ -83,9 +89,9 @@ EMSCRIPTEN_BINDINGS(whisper) { // run whisper { - whisper_reset_timings(g_context); - whisper_full(g_context, params, pcmf32.data(), pcmf32.size()); - whisper_print_timings(g_context); + whisper_reset_timings(g_state); + whisper_full_with_state(g_context, g_state, params, pcmf32.data(), pcmf32.size()); + whisper_print_timings(g_state); } return 0; diff --git a/whisper.cpp b/whisper.cpp index dedc58e..cb67dc9 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2481,7 +2481,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return NULL; + return nullptr; } { @@ -2491,7 +2491,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return NULL; + return nullptr; } { diff --git a/whisper.h b/whisper.h index 0d67125..7d01772 100644 --- a/whisper.h +++ b/whisper.h @@ -103,7 +103,7 @@ extern "C" { WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); // Function to create a new state used to run a transformation. - WHISPER_API struct whisper_state* whisper_init_state(whisper_context* ctx); + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); // Frees all memory allocated by the model. WHISPER_API void whisper_free(struct whisper_context * ctx);