diff --git a/examples/helpers.js b/examples/helpers.js new file mode 100644 index 0000000..a16a5e3 --- /dev/null +++ b/examples/helpers.js @@ -0,0 +1,176 @@ +// Common Javascript functions used by the examples + +function convertTypedArray(src, type) { + var buffer = new ArrayBuffer(src.byteLength); + var baseView = new src.constructor(buffer).set(src); + return new type(buffer); +} + +var printTextarea = (function() { + var element = document.getElementById('output'); + if (element) element.alue = ''; // clear browser cache + return function(text) { + if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' '); + console.log(text); + if (element) { + element.value += text + "\n"; + element.scrollTop = element.scrollHeight; // focus on bottom + } + }; +})(); + +// fetch a remote file from remote URL using the Fetch API +async function fetchRemote(url, cbProgress, cbPrint) { + cbPrint('fetchRemote: downloading with fetch()...'); + + const response = await fetch( + url, + { + method: 'GET', + headers: { + 'Content-Type': 'application/octet-stream', + }, + } + ); + + if (!response.ok) { + cbPrint('fetchRemote: failed to fetch ' + url); + return; + } + + const contentLength = response.headers.get('content-length'); + const total = parseInt(contentLength, 10); + const reader = response.body.getReader(); + + var chunks = []; + var receivedLength = 0; + var progressLast = -1; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + chunks.push(value); + receivedLength += value.length; + + if (contentLength) { + cbProgress(receivedLength/total); + + var progressCur = Math.round((receivedLength / total) * 10); + if (progressCur != progressLast) { + cbPrint('fetchRemote: fetching ' + 10*progressCur + '% ...'); + progressLast = progressCur; + } + } + } + + var position = 0; + var chunksAll = new Uint8Array(receivedLength); + + for (var chunk of chunks) { + chunksAll.set(chunk, position); + position += chunk.length; + } + + return chunksAll; +} + +// load remote data +// - check if the data is already in the IndexedDB +// - if not, fetch it from the remote URL and store it in the IndexedDB +function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) { + // query the storage quota and print it + navigator.storage.estimate().then(function (estimate) { + cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes'); + cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes'); + }); + + // check if the data is already in the IndexedDB + var rq = indexedDB.open(dbName, dbVersion); + + rq.onupgradeneeded = function (event) { + var db = event.target.result; + if (db.version == 1) { + var os = db.createObjectStore('models', { autoIncrement: false }); + cbPrint('loadRemote: created IndexedDB ' + db.name + ' version ' + db.version); + } else { + // clear the database + var os = event.currentTarget.transaction.objectStore('models'); + os.clear(); + cbPrint('loadRemote: cleared IndexedDB ' + db.name + ' version ' + db.version); + } + }; + + rq.onsuccess = function (event) { + var db = event.target.result; + var tx = db.transaction(['models'], 'readonly'); + var os = tx.objectStore('models'); + var rq = os.get(url); + + rq.onsuccess = function (event) { + if (rq.result) { + cbPrint('loadRemote: "' + url + '" is already in the IndexedDB'); + cbReady(dst, rq.result); + } else { + // data is not in the IndexedDB + cbPrint('loadRemote: "' + url + '" is not in the IndexedDB'); + + // alert and ask the user to confirm + if (!confirm( + 'You are about to download ' + size_mb + ' MB of data.\n' + + 'The model data will be cached in the browser for future use.\n\n' + + 'Press OK to continue.')) { + cbCancel(); + return; + } + + fetchRemote(url, cbProgress, cbPrint).then(function (data) { + if (data) { + // store the data in the IndexedDB + var rq = indexedDB.open(dbName, dbVersion); + rq.onsuccess = function (event) { + var db = event.target.result; + var tx = db.transaction(['models'], 'readwrite'); + var os = tx.objectStore('models'); + var rq = os.put(data, url); + + rq.onsuccess = function (event) { + cbPrint('loadRemote: "' + url + '" stored in the IndexedDB'); + cbReady(dst, data); + }; + + rq.onerror = function (event) { + cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB'); + cbCancel(); + }; + }; + } + }); + } + }; + + rq.onerror = function (event) { + cbPrint('loadRemote: failed to get data from the IndexedDB'); + cbCancel(); + }; + }; + + rq.onerror = function (event) { + cbPrint('loadRemote: failed to open IndexedDB'); + cbCancel(); + }; + + rq.onblocked = function (event) { + cbPrint('loadRemote: failed to open IndexedDB: blocked'); + cbCancel(); + }; + + rq.onabort = function (event) { + cbPrint('loadRemote: failed to open IndexedDB: abort'); + + }; +} + diff --git a/examples/talk.wasm/CMakeLists.txt b/examples/talk.wasm/CMakeLists.txt index 35f6223..567bc74 100644 --- a/examples/talk.wasm/CMakeLists.txt +++ b/examples/talk.wasm/CMakeLists.txt @@ -45,3 +45,4 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ set(TARGET talk.wasm) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY) diff --git a/examples/talk.wasm/README.md b/examples/talk.wasm/README.md index 54da098..9d8c8b1 100644 --- a/examples/talk.wasm/README.md +++ b/examples/talk.wasm/README.md @@ -61,9 +61,8 @@ emcmake cmake .. make -j # copy the produced page to your HTTP path -cp bin/talk.wasm/index.html /path/to/html/ -cp bin/talk.wasm/talk.js /path/to/html/ -cp bin/libtalk.worker.js /path/to/html/ +cp bin/talk.wasm/* /path/to/html/ +cp bin/libtalk.worker.js /path/to/html/ ``` ## Feedback diff --git a/examples/talk.wasm/emscripten.cpp b/examples/talk.wasm/emscripten.cpp index f0add29..d6a578c 100644 --- a/examples/talk.wasm/emscripten.cpp +++ b/examples/talk.wasm/emscripten.cpp @@ -62,7 +62,7 @@ void talk_main(size_t index) { wparams.print_special_tokens = false; wparams.max_tokens = 32; - wparams.audio_ctx = 768; + wparams.audio_ctx = 768; // partial encoder context for better performance wparams.language = "en"; @@ -133,7 +133,7 @@ void talk_main(size_t index) { } } - talk_set_status("processing ..."); + talk_set_status("processing audio (whisper)..."); t_last = t_now; @@ -192,7 +192,7 @@ void talk_main(size_t index) { text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), ""); text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), ""); - talk_set_status("'" + text_heard + "' - thinking how to respond ..."); + talk_set_status("'" + text_heard + "' - thinking how to respond (gpt-2) ..."); const std::vector tokens = gpt2_tokenize(g_gpt2, text_heard.c_str()); diff --git a/examples/talk.wasm/index-tmpl.html b/examples/talk.wasm/index-tmpl.html index be95b1d..8588e77 100644 --- a/examples/talk.wasm/index-tmpl.html +++ b/examples/talk.wasm/index-tmpl.html @@ -51,7 +51,7 @@

- Whisper model: + Whisper model: @@ -64,7 +64,7 @@
- GPT-2 model: + GPT-2 model: @@ -158,20 +158,8 @@
+