From 7094be1f37db71a63f4dd4c99e5ed415feb8be0b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Oct 2022 19:12:07 +0300 Subject: [PATCH] sync : whisper.cpp - Add MSVC header - FP16 GELU - C interface fixes (no unions) - Minor CMake updates --- CMakeLists.txt | 4 +- examples/whisper/main.cpp | 2 +- examples/whisper/whisper.cpp | 87 +++++++++++++++++++++--------------- examples/whisper/whisper.h | 35 +++++++-------- src/ggml.c | 18 ++++---- src/msvc_thread_atomic.h | 31 +++++++++++++ 6 files changed, 112 insertions(+), 65 deletions(-) create mode 100644 src/msvc_thread_atomic.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 73e174c..d2f95cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ endif() # options -option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON) +option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON) option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF) option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF) @@ -25,7 +25,7 @@ option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF) option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) -option(GGML_PERF "ggml: enable perf timings" ${GGML_PERF}) +option(GGML_PERF "ggml: enable perf timings" OFF) option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF) # sanitizers diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index b913522..995eefc 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -216,7 +216,7 @@ int main(int argc, char ** argv) { // run the inference { - whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_realtime = true; wparams.print_progress = false; diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 9885278..236fcf1 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -2256,51 +2256,63 @@ void whisper_print_timings(struct whisper_context * ctx) { //////////////////////////////////////////////////////////////////////////// -struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) { +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result; switch (strategy) { - case WHISPER_DECODE_GREEDY: + case WHISPER_SAMPLING_GREEDY: { result = { - .strategy = WHISPER_DECODE_GREEDY, - .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()), - .offset_ms = 0, + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - .translate = false, - .no_context = false, - .print_special_tokens = false, - .print_progress = true, - .print_realtime = false, - .print_timestamps = true, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, - .language = "en", + /*.translate =*/ false, + /*.no_context =*/ false, + /*.print_special_tokens =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - .greedy = { - .n_past = 0, + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ 0, + }, + + /*.beam_search =*/ { + /*.n_past =*/ -1, + /*.beam_width =*/ -1, + /*.n_best =*/ -1, }, }; } break; - case WHISPER_DECODE_BEAM_SEARCH: + case WHISPER_SAMPLING_BEAM_SEARCH: { result = { - .strategy = WHISPER_DECODE_GREEDY, - .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()), - .offset_ms = 0, - - .translate = false, - .no_context = false, - .print_special_tokens = false, - .print_progress = true, - .print_realtime = false, - .print_timestamps = true, - - .language = "en", - - .beam_search = { - .n_past = 0, - .beam_width = 10, - .n_best = 5, + /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.print_special_tokens =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ -1, + }, + + /*.beam_search =*/ { + /*.n_past =*/ 0, + /*.beam_width =*/ 10, + /*.n_best =*/ 5, }, }; } break; @@ -2425,7 +2437,7 @@ int whisper_full( whisper_token id = 0; whisper_token tid = whisper_token_beg(ctx); - id = whisper_sample_best(ctx, result_len == 0 || i > 32); + id = whisper_sample_best(ctx, result_len == 0); if (i > 0) { tid = whisper_sample_timestamp(ctx); } @@ -2445,9 +2457,12 @@ int whisper_full( // end of text token if (id == whisper_token_eot(ctx)) { if (result_len == 0) { - // TODO: figure out how to resolve this - fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); - //result_len = i + 1; + if (seek + seek_delta + 100 >= whisper_n_len(ctx)) { + result_len = i + 1; + } else { + // TODO: figure out how to resolve this + fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); + } } break; } diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 381afd7..45faa5b 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -31,7 +31,8 @@ extern "C" { // // C interface // - + // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads + // concurrently. // // Basic usage: // @@ -153,14 +154,14 @@ extern "C" { //////////////////////////////////////////////////////////////////////////// - // Available decoding strategies - enum whisper_decode_strategy { - WHISPER_DECODE_GREEDY, // Always select the most probable token - WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet! + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // Always select the most probable token + WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! }; struct whisper_full_params { - enum whisper_decode_strategy strategy; + enum whisper_sampling_strategy strategy; int n_threads; int offset_ms; @@ -174,20 +175,18 @@ extern "C" { const char * language; - union { - struct { - int n_past; - } greedy; - - struct { - int n_past; - int beam_width; - int n_best; - } beam_search; - }; + struct { + int n_past; + } greedy; + + struct { + int n_past; + int beam_width; + int n_best; + } beam_search; }; - WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy); + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Uses the specified decoding strategy to obtain the text. diff --git a/src/ggml.c b/src/ggml.c index 7f11c96..4861f24 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -14,7 +14,6 @@ #include #include - #if defined _MSC_VER #include "msvc_thread_atomic.h" #else @@ -24,6 +23,7 @@ typedef void* thread_ret_t; #endif #define GGML_DEBUG 0 +#define GGML_GELU_FP16 #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 @@ -723,20 +723,22 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp } } +#ifdef GGML_GELU_FP16 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { uint16_t t; for (int i = 0; i < n; ++i) { ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = table_gelu_f16[t]; + y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]); } } - -//inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_gelu_f32(x[i]); -// } -//} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; } inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } diff --git a/src/msvc_thread_atomic.h b/src/msvc_thread_atomic.h new file mode 100644 index 0000000..52cd419 --- /dev/null +++ b/src/msvc_thread_atomic.h @@ -0,0 +1,31 @@ +#pragma once +#include + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int* ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int* ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { + out = CreateThread(NULL, 0, func, arg, 0, NULL); + return out != NULL; +} + +static int pthread_join(pthread_t thread, void* unused) { + return (int) WaitForSingleObject(thread, INFINITE); +} +