sync : whisper.cpp

- Add MSVC header
- FP16 GELU
- C interface fixes (no unions)
- Minor CMake updates
pull/12/head
Georgi Gerganov 2 years ago
parent 270829aa9f
commit 7094be1f37
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -25,7 +25,7 @@ option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
option(GGML_PERF "ggml: enable perf timings" ${GGML_PERF}) option(GGML_PERF "ggml: enable perf timings" OFF)
option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF) option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
# sanitizers # sanitizers

@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
// run the inference // run the inference
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = true; wparams.print_realtime = true;
wparams.print_progress = false; wparams.print_progress = false;

@ -2256,51 +2256,63 @@ void whisper_print_timings(struct whisper_context * ctx) {
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) { struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result; struct whisper_full_params result;
switch (strategy) { switch (strategy) {
case WHISPER_DECODE_GREEDY: case WHISPER_SAMPLING_GREEDY:
{ {
result = { result = {
.strategy = WHISPER_DECODE_GREEDY, /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
.n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
.offset_ms = 0,
.translate = false, /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
.no_context = false, /*.offset_ms =*/ 0,
.print_special_tokens = false,
.print_progress = true,
.print_realtime = false,
.print_timestamps = true,
.language = "en", /*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
.greedy = { /*.language =*/ "en",
.n_past = 0,
/*.greedy =*/ {
/*.n_past =*/ 0,
},
/*.beam_search =*/ {
/*.n_past =*/ -1,
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
}, },
}; };
} break; } break;
case WHISPER_DECODE_BEAM_SEARCH: case WHISPER_SAMPLING_BEAM_SEARCH:
{ {
result = { result = {
.strategy = WHISPER_DECODE_GREEDY, /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
.n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
.offset_ms = 0, /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.offset_ms =*/ 0,
.translate = false,
.no_context = false, /*.translate =*/ false,
.print_special_tokens = false, /*.no_context =*/ false,
.print_progress = true, /*.print_special_tokens =*/ false,
.print_realtime = false, /*.print_progress =*/ true,
.print_timestamps = true, /*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
.language = "en",
/*.language =*/ "en",
.beam_search = {
.n_past = 0, /*.greedy =*/ {
.beam_width = 10, /*.n_past =*/ -1,
.n_best = 5, },
/*.beam_search =*/ {
/*.n_past =*/ 0,
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
}, },
}; };
} break; } break;
@ -2425,7 +2437,7 @@ int whisper_full(
whisper_token id = 0; whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx); whisper_token tid = whisper_token_beg(ctx);
id = whisper_sample_best(ctx, result_len == 0 || i > 32); id = whisper_sample_best(ctx, result_len == 0);
if (i > 0) { if (i > 0) {
tid = whisper_sample_timestamp(ctx); tid = whisper_sample_timestamp(ctx);
} }
@ -2445,9 +2457,12 @@ int whisper_full(
// end of text token // end of text token
if (id == whisper_token_eot(ctx)) { if (id == whisper_token_eot(ctx)) {
if (result_len == 0) { if (result_len == 0) {
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this // TODO: figure out how to resolve this
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
//result_len = i + 1; }
} }
break; break;
} }

@ -31,7 +31,8 @@ extern "C" {
// //
// C interface // C interface
// //
// The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads
// concurrently.
// //
// Basic usage: // Basic usage:
// //
@ -153,14 +154,14 @@ extern "C" {
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Available decoding strategies // Available sampling strategies
enum whisper_decode_strategy { enum whisper_sampling_strategy {
WHISPER_DECODE_GREEDY, // Always select the most probable token WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet! WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
}; };
struct whisper_full_params { struct whisper_full_params {
enum whisper_decode_strategy strategy; enum whisper_sampling_strategy strategy;
int n_threads; int n_threads;
int offset_ms; int offset_ms;
@ -174,7 +175,6 @@ extern "C" {
const char * language; const char * language;
union {
struct { struct {
int n_past; int n_past;
} greedy; } greedy;
@ -185,9 +185,8 @@ extern "C" {
int n_best; int n_best;
} beam_search; } beam_search;
}; };
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text. // Uses the specified decoding strategy to obtain the text.

@ -14,7 +14,6 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#if defined _MSC_VER #if defined _MSC_VER
#include "msvc_thread_atomic.h" #include "msvc_thread_atomic.h"
#else #else
@ -24,6 +23,7 @@ typedef void* thread_ret_t;
#endif #endif
#define GGML_DEBUG 0 #define GGML_DEBUG 0
#define GGML_GELU_FP16
#if UINTPTR_MAX == 0xFFFFFFFF #if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4 #define GGML_MEM_ALIGN 4
@ -723,20 +723,22 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
} }
} }
#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t; uint16_t t;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]); ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t)); memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = table_gelu_f16[t]; y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]);
} }
} }
#else
//inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
// for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
// y[i] = ggml_gelu_f32(x[i]); y[i] = ggml_gelu_f32(x[i]);
// } }
//} }
#endif
inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; } inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }

@ -0,0 +1,31 @@
#pragma once
#include <Windows.h>
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
static void atomic_store(atomic_int* ptr, LONG val) {
InterlockedExchange(ptr, val);
}
static LONG atomic_load(atomic_int* ptr) {
return InterlockedCompareExchange(ptr, 0, 0);
}
static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
return InterlockedExchangeAdd(ptr, inc);
}
static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
out = CreateThread(NULL, 0, func, arg, 0, NULL);
return out != NULL;
}
static int pthread_join(pthread_t thread, void* unused) {
return (int) WaitForSingleObject(thread, INFINITE);
}
Loading…
Cancel
Save