Forgot to add some files

pull/494/head
Sandro Hanea 2 years ago
parent aac01503cc
commit edb25e4730

@ -16,13 +16,15 @@
#include <vector> #include <vector>
struct whisper_context * g_context; struct whisper_context * g_context;
struct whisper_state * g_state;
EMSCRIPTEN_BINDINGS(whisper) { EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) { if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str()); g_context = whisper_init_from_file(path_model.c_str());
if (g_context != nullptr) { if (g_context != nullptr) {
return true; g_state = whisper_init_state(g_context);
return g_state != nullptr;
} else { } else {
return false; return false;
} }
@ -36,10 +38,14 @@ EMSCRIPTEN_BINDINGS(whisper) {
whisper_free(g_context); whisper_free(g_context);
g_context = nullptr; g_context = nullptr;
} }
if (g_state) {
whisper_free_state(g_state);
g_state = nullptr;
}
})); }));
emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) { emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) {
if (g_context == nullptr) { if (g_context == nullptr || g_state == nullptr) {
return -1; return -1;
} }
@ -83,9 +89,9 @@ EMSCRIPTEN_BINDINGS(whisper) {
// run whisper // run whisper
{ {
whisper_reset_timings(g_context); whisper_reset_timings(g_state);
whisper_full(g_context, params, pcmf32.data(), pcmf32.size()); whisper_full_with_state(g_context, g_state, params, pcmf32.data(), pcmf32.size());
whisper_print_timings(g_context); whisper_print_timings(g_state);
} }
return 0; return 0;

@ -2481,7 +2481,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return NULL; return nullptr;
} }
{ {
@ -2491,7 +2491,7 @@ struct whisper_state* whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return NULL; return nullptr;
} }
{ {

@ -103,7 +103,7 @@ extern "C" {
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// Function to create a new state used to run a transformation. // Function to create a new state used to run a transformation.
WHISPER_API struct whisper_state* whisper_init_state(whisper_context* ctx); WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Frees all memory allocated by the model. // Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx); WHISPER_API void whisper_free(struct whisper_context * ctx);

Loading…
Cancel
Save